Commit 741d7685 authored by Matthew Holt's avatar Matthew Holt

Merge branch 'master' into fastcgi-methods

# Conflicts:
#	middleware/fastcgi/fastcgi.go
parents 600ee9a8 88e3a26c
language: go language: go
go: go:
- 1.4.3 - 1.6
- 1.5.2
- tip - tip
env:
- CGO_ENABLED=0
install: install:
- go get -d ./... - go get -t ./...
- go get golang.org/x/tools/cmd/vet - go get golang.org/x/tools/cmd/vet
script: script:
......
## Contributing to Caddy ## Contributing to Caddy
**[Join our dev chat on Gitter](https://gitter.im/mholt/caddy)** to chat with Welcome! Our community focuses on helping others and making Caddy the best it
other Caddy developers! (Dev chat only; try our can be. We gladly accept contributions and encourage you to get involved!
[support room](https://gitter.im/caddyserver/support) for help or
[general](https://gitter.im/caddyserver/general) for anything else.)
This project gladly accepts contributions and we encourage interested users to
get involved!
### Join us in chat
#### For small tweaks, bug fixes, and tests Please direct your discussion to the correct room:
Submit [pull requests](https://github.com/mholt/caddy/pulls) at any time. - **Dev Chat:** [gitter.im/mholt/caddy](https://gitter.im/mholt/caddy) - to chat
Bug fixes should be under test to assert correct behavior. Thank you for with other Caddy developers
helping out in simple ways! - **Support:**
[gitter.im/caddyserver/support](https://gitter.im/caddyserver/support) - to give
and get help
- **General:**
[gitter.im/caddyserver/general](https://gitter.im/caddyserver/general) - for
anything about Web development
#### Ideas, questions, bug reports ### Bug reports
Feel free to [open an issue](https://github.com/mholt/caddy/issues) with your First, please [search this repository](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93)
ideas, questions, and bug reports, if one does not already exist for it. Bug with a variety of keywords to ensure your bug is not already reported.
reports should state expected behavior and contain clear instructions for
isolating and reproducing the problem.
See [How to Report Bugs Effectively](http://www.chiark.greenend.org.uk/~sgtatham/bugs.html).
If not, [open an issue](https://github.com/mholt/caddy/issues) and answer the
questions so we can understand and reproduce the problematic behavior.
#### New features The burden is on you to convince us that it is actually a bug in Caddy. This is
easiest to do when you write clear, concise instructions so we can reproduce
the behavior (even if it seems obvious). The more detailed and specific you are,
the faster we will be able to help you. Check out
[How to Report Bugs Effectively](http://www.chiark.greenend.org.uk/~sgtatham/bugs.html).
Before submitting a pull request, please open an issue first to discuss it and Please be kind. :smile: Remember that Caddy comes at no cost to you, and you're
claim it. This prevents overlapping efforts and keeps the project in-line with getting free help. If we helped you, please consider
its goals. If you prefer to discuss the feature privately, you can reach other [donating](https://caddyserver.com/donate) - it keeps us motivated!
developers on Gitter or you may email me directly. (My email address is below.)
And don't forget to write tests for new features!
### Minor improvements and new tests
#### Vulnerabilities Submit [pull requests](https://github.com/mholt/caddy/pulls) at any time. Make
sure to write tests to assert your change is working properly and is thoroughly
covered.
### Proposals, suggestions, ideas, new features
First, please [search](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93)
with a variety of keywords to ensure your suggestion/proposal is new.
If so, you may open either an issue or a pull request for discussion and
feedback.
The advantage of issues is that you don't have to spend time actually
implementing your idea, but you should still describe it thoroughly. The
advantage of a pull request is that we can immediately see the impact the change
will have on the project, what the code will look like, and how to improve it.
The disadvantage of pull requests is that they are unlikely to get accepted
without significant changes, or it may be rejected entirely. Don't worry, that
won't happen without an open discussion first.
If you are going to spend significant time implementing code for a pull request,
best to open an issue first and "claim" it and get feedback before you invest
a lot of time.
### Vulnerabilities
If you've found a vulnerability that is serious, please email me: Matthew dot If you've found a vulnerability that is serious, please email me: Matthew dot
Holt at Gmail. If it's not a big deal, a pull request will probably be faster. Holt at Gmail. If it's not a big deal, a pull request will probably be faster.
...@@ -43,4 +73,5 @@ Holt at Gmail. If it's not a big deal, a pull request will probably be faster. ...@@ -43,4 +73,5 @@ Holt at Gmail. If it's not a big deal, a pull request will probably be faster.
## Thank you ## Thank you
Thanks for your help! Caddy would not be what it is today without your contributions. Thanks for your help! Caddy would not be what it is today without your
contributions.
*If you are filing a bug report, please answer these questions. If your issue is not a bug report, you do not need to use this template. Either way, please consider donating if we've helped you. Thanks!*
#### 1. What version of Caddy are you running (`caddy -version`)?
#### 2. What are you trying to do?
#### 3. What is your entire Caddyfile?
```text
(Put Caddyfile here)
```
#### 4. How did you run Caddy (give the full command and describe the execution environment)?
#### 5. What did you expect to see?
#### 6. What did you see instead (give full error messages and/or log)?
...@@ -96,7 +96,7 @@ You may also be interested in the [developer guide] ...@@ -96,7 +96,7 @@ You may also be interested in the [developer guide]
## Running from Source ## Running from Source
Note: You will need **[Go 1.4](https://golang.org/dl/)** or a later version. Note: You will need **[Go 1.6](https://golang.org/dl/)** or newer.
1. `$ go get github.com/mholt/caddy` 1. `$ go get github.com/mholt/caddy`
2. `cd` into your website's directory 2. `cd` into your website's directory
......
...@@ -6,14 +6,21 @@ clone_folder: c:\gopath\src\github.com\mholt\caddy ...@@ -6,14 +6,21 @@ clone_folder: c:\gopath\src\github.com\mholt\caddy
environment: environment:
GOPATH: c:\gopath GOPATH: c:\gopath
CGO_ENABLED: 0
install: install:
- go get golang.org/x/tools/cmd/vet - rmdir c:\go /s /q
- echo %GOPATH% - appveyor DownloadFile https://storage.googleapis.com/golang/go1.6.windows-amd64.zip
- 7z x go1.6.windows-amd64.zip -y -oC:\ > NUL
- go version - go version
- go env - go env
- go get -d ./... - go get golang.org/x/tools/cmd/vet
- go get -t ./...
build_script: build: off
test_script:
- go vet ./... - go vet ./...
- go test ./... - go test ./...
\ No newline at end of file
deploy: off
#!/usr/bin/env bash
#
# Caddy build script. Automates proper versioning.
#
# Usage:
#
# $ ./build.bash [output_filename]
#
# Outputs compiled program in current directory.
# Default file name is 'ecaddy'.
#
set -e
output="$1"
if [ -z "$output" ]; then
output="ecaddy"
fi
pkg=main
# Timestamp of build
builddate_id=$pkg.buildDate
builddate=`date -u`
# Current tag, if HEAD is on a tag
tag_id=$pkg.gitTag
set +e
tag=`git describe --exact-match HEAD 2> /dev/null`
set -e
# Nearest tag on branch
lasttag_id=$pkg.gitNearestTag
lasttag=`git describe --abbrev=0 --tags HEAD`
# Commit SHA
commit_id=$pkg.gitCommit
commit=`git rev-parse --short HEAD`
# Summary of uncommited changes
shortstat_id=$pkg.gitShortStat
shortstat=`git diff-index --shortstat HEAD`
# List of modified files
files_id=$pkg.gitFilesModified
files=`git diff-index --name-only HEAD`
go build -ldflags "
-X \"$builddate_id=$builddate\"
-X \"$tag_id=$tag\"
-X \"$lasttag_id=$lasttag\"
-X \"$commit_id=$commit\"
-X \"$shortstat_id=$shortstat\"
-X \"$files_id=$files\"
" -o "$output"
...@@ -26,9 +26,10 @@ import ( ...@@ -26,9 +26,10 @@ import (
"path" "path"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
) )
...@@ -44,7 +45,7 @@ var ( ...@@ -44,7 +45,7 @@ var (
Quiet bool Quiet bool
// HTTP2 indicates whether HTTP2 is enabled or not. // HTTP2 indicates whether HTTP2 is enabled or not.
HTTP2 bool // TODO: temporary flag until http2 is standard HTTP2 bool
// PidFile is the path to the pidfile to create. // PidFile is the path to the pidfile to create.
PidFile string PidFile string
...@@ -191,8 +192,13 @@ func startServers(groupings bindingGroup) error { ...@@ -191,8 +192,13 @@ func startServers(groupings bindingGroup) error {
if err != nil { if err != nil {
return err return err
} }
s.HTTP2 = HTTP2 // TODO: This setting is temporary s.HTTP2 = HTTP2
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running
if s.OnDemandTLS {
s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome!
} else {
s.TLSConfig.GetCertificate = https.GetCertificate
}
var ln server.ListenerFile var ln server.ListenerFile
if IsRestart() { if IsRestart() {
...@@ -277,7 +283,7 @@ func startServers(groupings bindingGroup) error { ...@@ -277,7 +283,7 @@ func startServers(groupings bindingGroup) error {
// It does NOT execute shutdown callbacks that may have been // It does NOT execute shutdown callbacks that may have been
// configured by middleware (they must be executed separately). // configured by middleware (they must be executed separately).
func Stop() error { func Stop() error {
letsencrypt.Deactivate() https.Deactivate()
serversMu.Lock() serversMu.Lock()
for _, s := range servers { for _, s := range servers {
...@@ -312,6 +318,7 @@ func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) { ...@@ -312,6 +318,7 @@ func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) {
return nil, err return nil, err
} }
cdyfile = loadedGob.Caddyfile cdyfile = loadedGob.Caddyfile
atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued)
} }
// Try user's loader // Try user's loader
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
) )
func TestCaddyStartStop(t *testing.T) { func TestCaddyStartStop(t *testing.T) {
caddyfile := "localhost:1984\ntls off" caddyfile := "localhost:1984"
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := Start(CaddyfileInput{Contents: []byte(caddyfile)}) err := Start(CaddyfileInput{Contents: []byte(caddyfile)})
......
...@@ -8,10 +8,9 @@ import ( ...@@ -8,10 +8,9 @@ import (
"net" "net"
"sync" "sync"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/caddy/parse" "github.com/mholt/caddy/caddy/parse"
"github.com/mholt/caddy/caddy/setup" "github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
) )
...@@ -55,7 +54,6 @@ func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Con ...@@ -55,7 +54,6 @@ func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Con
Port: addr.Port, Port: addr.Port,
Scheme: addr.Scheme, Scheme: addr.Scheme,
Root: Root, Root: Root,
Middleware: make(map[string][]middleware.Middleware),
ConfigFile: filename, ConfigFile: filename,
AppName: AppName, AppName: AppName,
AppVersion: AppVersion, AppVersion: AppVersion,
...@@ -89,8 +87,7 @@ func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Con ...@@ -89,8 +87,7 @@ func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Con
return nil, nil, lastDirectiveIndex, err return nil, nil, lastDirectiveIndex, err
} }
if midware != nil { if midware != nil {
// TODO: For now, we only support the default path scope / config.Middleware = append(config.Middleware, midware)
config.Middleware["/"] = append(config.Middleware["/"], midware)
} }
storages[dir.name] = controller.ServerBlockStorage // persist for this server block storages[dir.name] = controller.ServerBlockStorage // persist for this server block
} }
...@@ -128,7 +125,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -128,7 +125,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
if !IsRestart() && !Quiet { if !IsRestart() && !Quiet {
fmt.Print("Activating privacy features...") fmt.Print("Activating privacy features...")
} }
configs, err = letsencrypt.Activate(configs) configs, err = https.Activate(configs)
if err != nil { if err != nil {
return nil, err return nil, err
} else if !IsRestart() && !Quiet { } else if !IsRestart() && !Quiet {
...@@ -171,8 +168,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -171,8 +168,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
return nil, err return nil, err
} }
if midware != nil { if midware != nil {
// TODO: For now, we only support the default path scope / configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware)
configs[configIndex].Middleware["/"] = append(configs[configIndex].Middleware["/"], midware)
} }
storages[dir.name] = controller.ServerBlockStorage // persist for this server block storages[dir.name] = controller.ServerBlockStorage // persist for this server block
} }
...@@ -318,7 +314,7 @@ func validDirective(d string) bool { ...@@ -318,7 +314,7 @@ func validDirective(d string) bool {
// root. // root.
func DefaultInput() CaddyfileInput { func DefaultInput() CaddyfileInput {
port := Port port := Port
if letsencrypt.HostQualifies(Host) && port == DefaultPort { if https.HostQualifies(Host) && port == DefaultPort {
port = "443" port = "443"
} }
return CaddyfileInput{ return CaddyfileInput{
......
package caddy package caddy
import ( import (
"github.com/mholt/caddy/caddy/https"
"github.com/mholt/caddy/caddy/parse" "github.com/mholt/caddy/caddy/parse"
"github.com/mholt/caddy/caddy/setup" "github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
...@@ -43,7 +44,7 @@ var directiveOrder = []directive{ ...@@ -43,7 +44,7 @@ var directiveOrder = []directive{
// Essential directives that initialize vital configuration settings // Essential directives that initialize vital configuration settings
{"root", setup.Root}, {"root", setup.Root},
{"bind", setup.BindHost}, {"bind", setup.BindHost},
{"tls", setup.TLS}, // letsencrypt is set up just after tls {"tls", https.Setup},
// Other directives that don't create HTTP handlers // Other directives that don't create HTTP handlers
{"startup", setup.Startup}, {"startup", setup.Startup},
...@@ -68,6 +69,23 @@ var directiveOrder = []directive{ ...@@ -68,6 +69,23 @@ var directiveOrder = []directive{
{"browse", setup.Browse}, {"browse", setup.Browse},
} }
// RegisterDirective adds the given directive to caddy's list of directives.
// Pass the name of a directive you want it to be placed after,
// otherwise it will be placed at the bottom of the stack.
func RegisterDirective(name string, setup SetupFunc, after string) {
dir := directive{name: name, setup: setup}
idx := len(directiveOrder)
for i := range directiveOrder {
if directiveOrder[i].name == after {
idx = i + 1
break
}
}
newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...)
directiveOrder = newDirectives
parse.ValidDirectives[name] = struct{}{}
}
// directive ties together a directive name with its setup function. // directive ties together a directive name with its setup function.
type directive struct { type directive struct {
name string name string
......
package caddy
import (
"reflect"
"testing"
)
func TestRegister(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
RegisterDirective("foo", nil, "dummy")
if len(directiveOrder) != 3 {
t.Fatal("Should have 3 directives now")
}
getNames := func() (s []string) {
for _, d := range directiveOrder {
s = append(s, d.name)
}
return s
}
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) {
t.Fatalf("directive order doesn't match: %s", getNames())
}
RegisterDirective("bar", nil, "ASDASD")
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) {
t.Fatalf("directive order doesn't match: %s", getNames())
}
}
...@@ -11,14 +11,8 @@ import ( ...@@ -11,14 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"github.com/mholt/caddy/caddy/letsencrypt"
) )
func init() {
letsencrypt.OnChange = func() error { return Restart(nil) }
}
// isLocalhost returns true if host looks explicitly like a localhost address. // isLocalhost returns true if host looks explicitly like a localhost address.
func isLocalhost(host string) bool { func isLocalhost(host string) bool {
return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.") return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.")
...@@ -69,10 +63,12 @@ var signalParentOnce sync.Once ...@@ -69,10 +63,12 @@ var signalParentOnce sync.Once
// caddyfileGob maps bind address to index of the file descriptor // caddyfileGob maps bind address to index of the file descriptor
// in the Files array passed to the child process. It also contains // in the Files array passed to the child process. It also contains
// the caddyfile contents. Used only during graceful restarts. // the caddyfile contents and other state needed by the new process.
// Used only during graceful restarts where a new process is spawned.
type caddyfileGob struct { type caddyfileGob struct {
ListenerFds map[string]uintptr ListenerFds map[string]uintptr
Caddyfile Input Caddyfile Input
OnDemandTLSCertsIssued int32
} }
// IsRestart returns whether this process is, according // IsRestart returns whether this process is, according
......
package https
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"log"
"strings"
"sync"
"time"
"github.com/xenolf/lego/acme"
"golang.org/x/crypto/ocsp"
)
// certCache stores certificates in memory,
// keying certificates by name.
var certCache = make(map[string]Certificate)
var certCacheMu sync.RWMutex
// Certificate is a tls.Certificate with associated metadata tacked on.
// Even if the metadata can be obtained by parsing the certificate,
// we can be more efficient by extracting the metadata once so it's
// just there, ready to use.
type Certificate struct {
tls.Certificate
// Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN.
Names []string
// NotAfter is when the certificate expires.
NotAfter time.Time
// Managed certificates are certificates that Caddy is managing,
// as opposed to the user specifying a certificate and key file
// or directory and managing the certificate resources themselves.
Managed bool
// OnDemand certificates are obtained or loaded on-demand during TLS
// handshakes (as opposed to preloaded certificates, which are loaded
// at startup). If OnDemand is true, Managed must necessarily be true.
// OnDemand certificates are maintained in the background just like
// preloaded ones, however, if an OnDemand certificate fails to renew,
// it is removed from the in-memory cache.
OnDemand bool
// OCSP contains the certificate's parsed OCSP response.
OCSP *ocsp.Response
}
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache. If there is no exact match for name, it
// will be checked against names of the form '*.example.com' (wildcard
// certificates) according to RFC 6125. If a match is found, matched will
// be true. If no matches are found, matched will be false and a default
// certificate will be returned with defaulted set to true. If no default
// certificate is set, defaulted will be set to false.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
certCacheMu.RLock()
defer certCacheMu.RUnlock()
// exact match? great, let's use it
if cert, ok = certCache[name]; ok {
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok = certCache[candidate]; ok {
matched = true
return
}
}
// if nothing matches, use the default certificate or bust
cert, defaulted = certCache[""]
return
}
// cacheManagedCertificate loads the certificate for domain into the
// cache, flagging it as Managed and, if onDemand is true, as OnDemand
// (meaning that it was obtained or loaded during a TLS handshake).
//
// This function is safe for concurrent use.
func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
if err != nil {
return cert, err
}
cert.Managed = true
cert.OnDemand = onDemand
cacheCertificate(cert)
return cert, nil
}
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
// and keyFile, which must be in PEM format. It stores the certificate in
// memory. The Managed and OnDemand flags of the certificate will be set to
// false.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
cert, err := makeCertificateFromDisk(certFile, keyFile)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
// of the certificate and key, then caches it in memory.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
cert, err := makeCertificate(certBytes, keyBytes)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// makeCertificateFromDisk makes a Certificate by loading the
// certificate and key files. It fills out all the fields in
// the certificate except for the Managed and OnDemand flags.
// (It is up to the caller to set those.)
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := ioutil.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return makeCertificate(certPEMBlock, keyPEMBlock)
}
// makeCertificate turns a certificate PEM bundle and a key PEM block into
// a Certificate, with OCSP and other relevant metadata tagged with it,
// except for the OnDemand and Managed flags. It is up to the caller to
// set those properties.
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
var cert Certificate
// Convert to a tls.Certificate
tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return cert, err
}
if len(tlsCert.Certificate) == 0 {
return cert, errors.New("certificate is empty")
}
// Parse leaf certificate and extract relevant metadata
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return cert, err
}
if leaf.Subject.CommonName != "" {
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
}
for _, name := range leaf.DNSNames {
if name != leaf.Subject.CommonName {
cert.Names = append(cert.Names, strings.ToLower(name))
}
}
cert.NotAfter = leaf.NotAfter
// Staple OCSP
ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
if err != nil {
// An error here is not a problem because a certificate may simply
// not contain a link to an OCSP server. But we should log it anyway.
log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
} else if ocspResp.Status == ocsp.Good {
tlsCert.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
}
cert.Certificate = tlsCert
return cert, nil
}
// cacheCertificate adds cert to the in-memory cache. If the cache is
// empty, cert will be used as the default certificate. If the cache is
// full, random entries are deleted until there is room to map all the
// names on the certificate.
//
// This certificate will be keyed to the names in cert.Names. Any name
// that is already a key in the cache will be replaced with this cert.
//
// This function is safe for concurrent use.
func cacheCertificate(cert Certificate) {
certCacheMu.Lock()
if _, ok := certCache[""]; !ok {
// use as default
cert.Names = append(cert.Names, "")
certCache[""] = cert
}
for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements
for key := range certCache {
if key == "" { // ... but not the default cert
continue
}
delete(certCache, key)
break
}
}
for _, name := range cert.Names {
certCache[name] = cert
}
certCacheMu.Unlock()
}
package https
import "testing"
func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
// When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches, the default is returned
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
}
}
func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
if _, ok := certCache["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
}
if _, ok := certCache["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
}
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
}
cacheCertificate(Certificate{Names: []string{"example2.com"}})
if _, ok := certCache["example2.com"]; !ok {
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
}
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
t.Error("Expected second cert to NOT be cached as default, but it was")
}
}
package https
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"sync"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// acmeMu ensures that only one ACME challenge occurs at a time.
var acmeMu sync.Mutex
// ACMEClient is an acme.Client with custom state attached.
type ACMEClient struct {
*acme.Client
AllowPrompts bool // if false, we assume AlternatePort must be used
}
// NewACMEClient creates a new ACMEClient given an email and whether
// prompting the user is allowed. Clients should not be kept and
// re-used over long periods of time, but immediate re-use is more
// efficient than re-creating on every iteration.
var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
// Look up or create the LE user account
leUser, err := getUser(email)
if err != nil {
return nil, err
}
// The client facilitates our communication with the CA server.
client, err := acme.NewClient(CAUrl, &leUser, KeyType)
if err != nil {
return nil, err
}
// If not registered, the user must register an account with the CA
// and agree to terms
if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" {
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
}
if !Agreed && reg.TosURL == "" {
return nil, errors.New("user must agree to terms")
}
}
err = client.AgreeToTOS()
if err != nil {
saveUser(leUser) // Might as well try, right?
return nil, errors.New("error agreeing to terms: " + err.Error())
}
// save user to the file system
err = saveUser(leUser)
if err != nil {
return nil, errors.New("could not save user: " + err.Error())
}
}
return &ACMEClient{
Client: client,
AllowPrompts: allowPrompts,
}, nil
}
// NewACMEClientGetEmail creates a new ACMEClient and gets an email
// address at the same time (a server config is required, since it
// may contain an email address in it).
func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
}
// Configure configures c according to bindHost, which is the host (not
// whole address) to bind the listener to in solving the http and tls-sni
// challenges.
func (c *ACMEClient) Configure(bindHost string) {
// If we allow prompts, operator must be present. In our case,
// that is synonymous with saying the server is not already
// started. So if the user is still there, we don't use
// AlternatePort because we don't need to proxy the challenges.
// Conversely, if the operator is not there, the server has
// already started and we need to proxy the challenge.
if c.AllowPrompts {
// Operator is present; server is not already listening
c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
//c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
} else {
// Operator is not present; server is started, so proxy challenges
c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
//c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
}
c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
}
// Obtain obtains a single certificate for names. It stores the certificate
// on the disk if successful.
func (c *ACMEClient) Obtain(names []string) error {
Attempts:
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
certificate, failures := c.ObtainCertificate(names, true, nil)
acmeMu.Unlock()
if len(failures) > 0 {
// Error - try to fix it or report it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg)
}
// Success - immediately save the certificate resource
err := saveCertResource(certificate)
if err != nil {
return fmt.Errorf("error saving assets for %v: %v", names, err)
}
break
}
return nil
}
// Renew renews the managed certificate for name. Right now our storage
// mechanism only supports one name per certificate, so this function only
// accepts one domain as input. It can be easily modified to support SAN
// certificates if, one day, they become desperately needed enough that our
// storage mechanism is upgraded to be more complex to support SAN certs.
//
// Anyway, this function is safe for concurrent use.
func (c *ACMEClient) Renew(name string) error {
// Prepare for renewal (load PEM cert, key, and meta)
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
if err != nil {
return err
}
keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
if err != nil {
return err
}
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
if err != nil {
return err
}
var certMeta acme.CertificateResource
err = json.Unmarshal(metaBytes, &certMeta)
certMeta.Certificate = certBytes
certMeta.PrivateKey = keyBytes
// Perform renewal and retry if necessary, but not too many times.
var newCertMeta acme.CertificateResource
var success bool
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
newCertMeta, err = c.RenewCertificate(certMeta, true)
acmeMu.Unlock()
if err == nil {
success = true
break
}
// If the legal terms changed and need to be agreed to again,
// we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.AgreeToTOS()
if err != nil {
return err
}
continue
}
// For any other kind of error, wait 10s and try again.
time.Sleep(10 * time.Second)
}
if !success {
return errors.New("too many renewal attempts; last error: " + err.Error())
}
return saveCertResource(newCertMeta)
}
package https
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io/ioutil"
"os"
)
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file.
func loadPrivateKey(file string) (crypto.PrivateKey, error) {
keyBytes, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
keyBlock, _ := pem.Decode(keyBytes)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
}
return nil, errors.New("unknown private key type")
}
// savePrivateKey saves a PEM-encoded ECC/RSA private key to file.
func savePrivateKey(key crypto.PrivateKey, file string) error {
var pemType string
var keyBytes []byte
switch key := key.(type) {
case *ecdsa.PrivateKey:
var err error
pemType = "EC"
keyBytes, err = x509.MarshalECPrivateKey(key)
if err != nil {
return err
}
case *rsa.PrivateKey:
pemType = "RSA"
keyBytes = x509.MarshalPKCS1PrivateKey(key)
}
pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
keyOut, err := os.Create(file)
if err != nil {
return err
}
keyOut.Chmod(0600)
defer keyOut.Close()
return pem.Encode(keyOut, &pemKey)
}
package letsencrypt package https
import ( import (
"bytes" "bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
...@@ -10,21 +13,17 @@ import ( ...@@ -10,21 +13,17 @@ import (
"testing" "testing"
) )
func init() {
rsaKeySizeToUse = 128 // make tests faster; small key size OK for testing
}
func TestSaveAndLoadRSAPrivateKey(t *testing.T) { func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
keyFile := "test.key" keyFile := "test.key"
defer os.Remove(keyFile) defer os.Remove(keyFile)
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySizeToUse) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// test save // test save
err = saveRSAPrivateKey(privateKey, keyFile) err = savePrivateKey(privateKey, keyFile)
if err != nil { if err != nil {
t.Fatal("error saving private key:", err) t.Fatal("error saving private key:", err)
} }
...@@ -43,23 +42,70 @@ func TestSaveAndLoadRSAPrivateKey(t *testing.T) { ...@@ -43,23 +42,70 @@ func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
} }
// test load // test load
loadedKey, err := loadRSAPrivateKey(keyFile) loadedKey, err := loadPrivateKey(keyFile)
if err != nil { if err != nil {
t.Error("error loading private key:", err) t.Error("error loading private key:", err)
} }
// verify loaded key is correct // verify loaded key is correct
if !rsaPrivateKeysSame(privateKey, loadedKey) { if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't") t.Error("Expected key bytes to be the same, but they weren't")
} }
} }
// rsaPrivateKeysSame compares the bytes of a and b and returns true if they are the same. func TestSaveAndLoadECCPrivateKey(t *testing.T) {
func rsaPrivateKeysSame(a, b *rsa.PrivateKey) bool { keyFile := "test.key"
return bytes.Equal(rsaPrivateKeyBytes(a), rsaPrivateKeyBytes(b)) defer os.Remove(keyFile)
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
// test save
err = savePrivateKey(privateKey, keyFile)
if err != nil {
t.Fatal("error saving private key:", err)
}
// it doesn't make sense to test file permission on windows
if runtime.GOOS != "windows" {
// get info of the key file
info, err := os.Stat(keyFile)
if err != nil {
t.Fatal("error stating private key:", err)
}
// verify permission of key file is correct
if info.Mode().Perm() != 0600 {
t.Error("Expected key file to have permission 0600, but it wasn't")
}
}
// test load
loadedKey, err := loadPrivateKey(keyFile)
if err != nil {
t.Error("error loading private key:", err)
}
// verify loaded key is correct
if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't")
}
}
// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
func PrivateKeysSame(a, b crypto.PrivateKey) bool {
return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
} }
// rsaPrivateKeyBytes returns the bytes of DER-encoded key. // PrivateKeyBytes returns the bytes of DER-encoded key.
func rsaPrivateKeyBytes(key *rsa.PrivateKey) []byte { func PrivateKeyBytes(key crypto.PrivateKey) []byte {
return x509.MarshalPKCS1PrivateKey(key) var keyBytes []byte
switch key := key.(type) {
case *rsa.PrivateKey:
keyBytes = x509.MarshalPKCS1PrivateKey(key)
case *ecdsa.PrivateKey:
keyBytes, _ = x509.MarshalECPrivateKey(key)
}
return keyBytes
} }
package letsencrypt package https
import ( import (
"crypto/tls" "crypto/tls"
"log" "log"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
...@@ -23,21 +22,16 @@ func RequestCallback(w http.ResponseWriter, r *http.Request) bool { ...@@ -23,21 +22,16 @@ func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
scheme = "https" scheme = "https"
} }
hostname, _, err := net.SplitHostPort(r.URL.Host) upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
if err != nil {
hostname = r.URL.Host
}
upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] letsencrypt handler: %v", err) log.Printf("[ERROR] ACME proxy handler: %v", err)
return true return true
} }
proxy := httputil.NewSingleHostReverseProxy(upstream) proxy := httputil.NewSingleHostReverseProxy(upstream)
proxy.Transport = &http.Transport{ proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // client would use self-signed cert TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
} }
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
......
package letsencrypt package https
import ( import (
"net" "net"
......
package https
import (
"bytes"
"crypto/tls"
"encoding/pem"
"errors"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// GetCertificate gets a certificate to satisfy clientHello as long as
// the certificate is already cached in memory. It will not be loaded
// from disk or obtained from the CA during the handshake.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
return &cert.Certificate, err
}
// GetOrObtainCertificate will get a certificate to satisfy clientHello, even
// if that means obtaining a new certificate from a CA during the handshake.
// It first checks the in-memory cache, then accesses disk, then accesses the
// network if it must. An obtained certificate will be stored on disk and
// cached in memory.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
return &cert.Certificate, err
}
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache and if
// loadIfNecessary == true, it goes to disk to load it into the cache and
// serve it. If it's not on disk and if obtainIfNecessary == true, the
// certificate will be obtained from the CA, cached, and served. If
// obtainIfNecessary is true, then loadIfNecessary must also be set to true.
// An error will be returned if and only if no certificate is available.
//
// This function is safe for concurrent use.
func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
// First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := getCertificate(name)
if matched {
return cert, nil
}
if loadIfNecessary {
// Then check to see if we have one on disk
loadedCert, err := cacheManagedCertificate(name, true)
if err == nil {
loadedCert, err = handshakeMaintenance(name, loadedCert)
if err != nil {
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
}
return loadedCert, nil
}
if obtainIfNecessary {
// By this point, we need to ask the CA for a certificate
name = strings.ToLower(name)
// Make sure aren't over any applicable limits
err := checkLimitsForObtainingNewCerts(name)
if err != nil {
return Certificate{}, err
}
// Name has to qualify for a certificate
if !HostQualifies(name) {
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
}
// Obtain certificate from the CA
return obtainOnDemandCertificate(name)
}
}
if defaulted {
return cert, nil
}
return Certificate{}, errors.New("no certificate for " + name)
}
// checkLimitsForObtainingNewCerts checks to see if name can be issued right
// now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate
// for name.
func checkLimitsForObtainingNewCerts(name string) error {
// User can set hard limit for number of certs for the process to issue
if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue {
return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue)
}
// Make sure name hasn't failed a challenge recently
failedIssuanceMu.RLock()
when, ok := failedIssuance[name]
failedIssuanceMu.RUnlock()
if ok {
return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
}
// Make sure, if we've issued a few certificates already, that we haven't
// issued any recently
lastIssueTimeMu.Lock()
since := time.Since(lastIssueTime)
lastIssueTimeMu.Unlock()
if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute {
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
}
// 👍Good to go
return nil
}
// obtainOnDemandCertificate obtains a certificate for name for the given
// name. If another goroutine has already started obtaining a cert for
// name, it will wait and use what the other goroutine obtained.
//
// This function is safe for use by multiple concurrent goroutines.
func obtainOnDemandCertificate(name string) (Certificate, error) {
// We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
// lucky us -- another goroutine is already obtaining the certificate.
// wait for it to finish obtaining the cert and then we'll use it.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and obtain the cert
wait = make(chan struct{})
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// Unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitChansMu.Lock()
close(wait)
delete(obtainCertWaitChans, name)
obtainCertWaitChansMu.Unlock()
}()
log.Printf("[INFO] Obtaining new certificate for %s", name)
// obtain cert
client, err := NewACMEClientGetEmail(server.Config{}, false)
if err != nil {
return Certificate{}, errors.New("error creating client: " + err.Error())
}
client.Configure("") // TODO: which BindHost?
err = client.Obtain([]string{name})
if err != nil {
// Failed to solve challenge, so don't allow another on-demand
// issue for this name to be attempted for a little while.
failedIssuanceMu.Lock()
failedIssuance[name] = time.Now()
go func(name string) {
time.Sleep(5 * time.Minute)
failedIssuanceMu.Lock()
delete(failedIssuance, name)
failedIssuanceMu.Unlock()
}(name)
failedIssuanceMu.Unlock()
return Certificate{}, err
}
// Success - update counters and stuff
atomic.AddInt32(OnDemandIssuedCount, 1)
lastIssueTimeMu.Lock()
lastIssueTime = time.Now()
lastIssueTimeMu.Unlock()
// The certificate is already on disk; now just start over to load it and serve it
return getCertDuringHandshake(name, true, false)
}
// handshakeMaintenance performs a check on cert for expiration and OCSP
// validity.
//
// This function is safe for use by multiple concurrent goroutines.
func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
// Check cert expiration
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
return renewDynamicCertificate(name)
}
// Check OCSP staple validity
if cert.OCSP != nil {
refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
if time.Now().After(refreshTime) {
err := stapleOCSP(&cert, nil)
if err != nil {
// An error with OCSP stapling is not the end of the world, and in fact, is
// quite common considering not all certs have issuer URLs that support it.
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
}
certCacheMu.Lock()
certCache[name] = cert
certCacheMu.Unlock()
}
}
return cert, nil
}
// renewDynamicCertificate renews currentCert using the clientHello. It returns the
// certificate to use and an error, if any. currentCert may be returned even if an
// error occurs, since we perform renewals before they expire and it may still be
// usable. name should already be lower-cased before calling this function.
//
// This function is safe for use by multiple concurrent goroutines.
func renewDynamicCertificate(name string) (Certificate, error) {
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
if ok {
// lucky us -- another goroutine is already renewing the certificate.
// wait for it to finish, then we'll use the new one.
obtainCertWaitChansMu.Unlock()
<-wait
return getCertDuringHandshake(name, true, false)
}
// looks like it's up to us to do all the work and renew the cert
wait = make(chan struct{})
obtainCertWaitChans[name] = wait
obtainCertWaitChansMu.Unlock()
// unblock waiters and delete waitgroup when we return
defer func() {
obtainCertWaitChansMu.Lock()
close(wait)
delete(obtainCertWaitChans, name)
obtainCertWaitChansMu.Unlock()
}()
log.Printf("[INFO] Renewing certificate for %s", name)
client, err := NewACMEClientGetEmail(server.Config{}, false)
if err != nil {
return Certificate{}, err
}
client.Configure("") // TODO: Bind address of relevant listener, yuck
err = client.Renew(name)
if err != nil {
return Certificate{}, err
}
return getCertDuringHandshake(name, true, false)
}
// stapleOCSP staples OCSP information to cert for hostname name.
// If you have it handy, you should pass in the PEM-encoded certificate
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
// If you don't have the PEM blocks handy, just pass in nil.
//
// Errors here are not necessarily fatal, it could just be that the
// certificate doesn't have an issuer URL.
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
if pemBundle == nil {
// The function in the acme package that gets OCSP requires a PEM-encoded cert
bundle := new(bytes.Buffer)
for _, derBytes := range cert.Certificate.Certificate {
pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}
pemBundle = bundle.Bytes()
}
ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle)
if err != nil {
return err
}
cert.Certificate.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
return nil
}
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
var obtainCertWaitChans = make(map[string]chan struct{})
var obtainCertWaitChansMu sync.Mutex
// OnDemandIssuedCount is the number of certificates that have been issued
// on-demand by this process. It is only safe to modify this count atomically.
// If it reaches onDemandMaxIssue, on-demand issuances will fail.
var OnDemandIssuedCount = new(int32)
// onDemandMaxIssue is set based on max_certs in tls config. It specifies the
// maximum number of certificates that can be issued.
// TODO: This applies globally, but we should probably make a server-specific
// way to keep track of these limits and counts, since it's specified in the
// Caddyfile...
var onDemandMaxIssue int32
// failedIssuance is a set of names that we recently failed to get a
// certificate for from the ACME CA. They are removed after some time.
// When a name is in this map, do not issue a certificate for it on-demand.
var failedIssuance = make(map[string]time.Time)
var failedIssuanceMu sync.RWMutex
// lastIssueTime records when we last obtained a certificate successfully.
// If this value is recent, do not make any on-demand certificate requests.
var lastIssueTime time.Time
var lastIssueTimeMu sync.Mutex
package https
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// When cache is empty
if cert, err := GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, err := GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
if cert, err := GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When no certificate matches, the default is returned
if cert, err := GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert)
}
}
// Package letsencrypt integrates Let's Encrypt functionality into Caddy // Package https facilitates the management of TLS assets and integrates
// with first-class support for creating and renewing certificates // Let's Encrypt functionality into Caddy with first-class support for
// automatically. It is designed to configure sites for HTTPS by default. // creating and renewing certificates automatically. It is designed to
package letsencrypt // configure sites for HTTPS by default.
package https
import ( import (
"encoding/json" "encoding/json"
...@@ -11,11 +12,7 @@ import ( ...@@ -11,11 +12,7 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
"golang.org/x/crypto/ocsp"
"github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/middleware/redirect" "github.com/mholt/caddy/middleware/redirect"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
...@@ -37,34 +34,27 @@ import ( ...@@ -37,34 +34,27 @@ import (
// //
// Also note that calling this function activates asset // Also note that calling this function activates asset
// management automatically, which keeps certificates // management automatically, which keeps certificates
// renewed and OCSP stapling updated. This has the effect // renewed and OCSP stapling updated.
// of causing restarts when assets are updated.
// //
// Activate returns the updated list of configs, since // Activate returns the updated list of configs, since
// some may have been appended, for example, to redirect // some may have been appended, for example, to redirect
// plaintext HTTP requests to their HTTPS counterpart. // plaintext HTTP requests to their HTTPS counterpart.
// This function only appends; it does not prepend or splice. // This function only appends; it does not splice.
func Activate(configs []server.Config) ([]server.Config, error) { func Activate(configs []server.Config) ([]server.Config, error) {
// just in case previous caller forgot... // just in case previous caller forgot...
Deactivate() Deactivate()
// reset cached ocsp from any previous activations
ocspCache = make(map[*[]byte]*ocsp.Response)
// pre-screen each config and earmark the ones that qualify for managed TLS // pre-screen each config and earmark the ones that qualify for managed TLS
MarkQualified(configs) MarkQualified(configs)
// place certificates and keys on disk // place certificates and keys on disk
err := ObtainCerts(configs, "") err := ObtainCerts(configs, true, false)
if err != nil { if err != nil {
return configs, err return configs, err
} }
// update TLS configurations // update TLS configurations
EnableTLS(configs) err = EnableTLS(configs, true)
// enable OCSP stapling (this affects all TLS-enabled configs)
err = StapleOCSP(configs)
if err != nil { if err != nil {
return configs, err return configs, err
} }
...@@ -77,10 +67,13 @@ func Activate(configs []server.Config) ([]server.Config, error) { ...@@ -77,10 +67,13 @@ func Activate(configs []server.Config) ([]server.Config, error) {
// the renewal ticker is reset, so if restarts happen more often than // the renewal ticker is reset, so if restarts happen more often than
// the ticker interval, renewals would never happen. but doing // the ticker interval, renewals would never happen. but doing
// it right away at start guarantees that renewals aren't missed. // it right away at start guarantees that renewals aren't missed.
renewCertificates(configs, false) err = renewManagedCertificates(true)
if err != nil {
return configs, err
}
// keep certificates renewed and OCSP stapling updated // keep certificates renewed and OCSP stapling updated
go maintainAssets(configs, stopChan) go maintainAssets(stopChan)
return configs, nil return configs, nil
} }
...@@ -101,7 +94,7 @@ func Deactivate() (err error) { ...@@ -101,7 +94,7 @@ func Deactivate() (err error) {
} }
// MarkQualified scans each config and, if it qualifies for managed // MarkQualified scans each config and, if it qualifies for managed
// TLS, it sets the Marked field of the TLSConfig to true. // TLS, it sets the Managed field of the TLSConfig to true.
func MarkQualified(configs []server.Config) { func MarkQualified(configs []server.Config) {
for i := 0; i < len(configs); i++ { for i := 0; i < len(configs); i++ {
if ConfigQualifies(configs[i]) { if ConfigQualifies(configs[i]) {
...@@ -110,58 +103,57 @@ func MarkQualified(configs []server.Config) { ...@@ -110,58 +103,57 @@ func MarkQualified(configs []server.Config) {
} }
} }
// ObtainCerts obtains certificates for all these configs as long as a certificate does not // ObtainCerts obtains certificates for all these configs as long as a
// already exist on disk. It does not modify the configs at all; it only obtains and stores // certificate does not already exist on disk. It does not modify the
// certificates and keys to the disk. // configs at all; it only obtains and stores certificates and keys to
func ObtainCerts(configs []server.Config, altPort string) error { // the disk. If allowPrompts is true, the user may be shown a prompt.
groupedConfigs := groupConfigsByEmail(configs, altPort != "") // don't prompt user if server already running // If proxyACME is true, the ACME challenges will be proxied to our alt port.
func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error {
// We group configs by email so we don't make the same clients over and
// over. This has the potential to prompt the user for an email, but we
// prevent that by assuming that if we already have a listener that can
// proxy ACME challenge requests, then the server is already running and
// the operator is no longer present.
groupedConfigs := groupConfigsByEmail(configs, allowPrompts)
for email, group := range groupedConfigs { for email, group := range groupedConfigs {
client, err := newClientPort(email, altPort) // Wait as long as we can before creating the client, because it
if err != nil { // may not be needed, for example, if we already have what we
return errors.New("error creating client: " + err.Error()) // need on disk. Creating a client involves the network and
} // potentially prompting the user, etc., so only do if necessary.
var client *ACMEClient
for _, cfg := range group { for _, cfg := range group {
if existingCertAndKey(cfg.Host) { if !HostQualifies(cfg.Host) || existingCertAndKey(cfg.Host) {
continue continue
} }
Obtain: // Now we definitely do need a client
certificate, failures := client.ObtainCertificate([]string{cfg.Host}, true, nil) if client == nil {
if len(failures) == 0 { var err error
// Success - immediately save the certificate resource client, err = NewACMEClient(email, allowPrompts)
err := saveCertResource(certificate)
if err != nil { if err != nil {
return errors.New("error saving assets for " + cfg.Host + ": " + err.Error()) return errors.New("error creating client: " + err.Error())
} }
}
// c.Configure assumes that allowPrompts == !proxyACME,
// but that's not always true. For example, a restart where
// the user isn't present and we're not listening on port 80.
// TODO: This could probably be refactored better.
if proxyACME {
client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
} else { } else {
// Error - either try to fix it or report them it to the user and abort client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, ""))
var errMsg string // we'll combine all the failures into a single error message client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, ""))
var promptedForAgreement bool // only prompt user for agreement at most once client.ExcludeChallenges([]acme.Challenge{acme.DNS01})
}
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && altPort == "" { // don't prompt if server is already running
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || altPort != "" {
err := client.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
goto Obtain
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg) err := client.Obtain([]string{cfg.Host})
if err != nil {
return err
} }
} }
} }
...@@ -169,17 +161,17 @@ func ObtainCerts(configs []server.Config, altPort string) error { ...@@ -169,17 +161,17 @@ func ObtainCerts(configs []server.Config, altPort string) error {
return nil return nil
} }
// groupConfigsByEmail groups configs by the email address to be used by its // groupConfigsByEmail groups configs by the email address to be used by an
// ACME client. It only includes configs that are marked as fully managed. // ACME client. It only groups configs that have TLS enabled and that are
// This is the function that may prompt for an email address, unless skipPrompt // marked as Managed. If userPresent is true, the operator MAY be prompted
// is true, in which case it will assume an empty email address. // for an email address.
func groupConfigsByEmail(configs []server.Config, skipPrompt bool) map[string][]server.Config { func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config {
initMap := make(map[string][]server.Config) initMap := make(map[string][]server.Config)
for _, cfg := range configs { for _, cfg := range configs {
if !cfg.TLS.Managed { if !cfg.TLS.Managed {
continue continue
} }
leEmail := getEmail(cfg, skipPrompt) leEmail := getEmail(cfg, userPresent)
initMap[leEmail] = append(initMap[leEmail], cfg) initMap[leEmail] = append(initMap[leEmail], cfg)
} }
return initMap return initMap
...@@ -187,48 +179,24 @@ func groupConfigsByEmail(configs []server.Config, skipPrompt bool) map[string][] ...@@ -187,48 +179,24 @@ func groupConfigsByEmail(configs []server.Config, skipPrompt bool) map[string][]
// EnableTLS configures each config to use TLS according to default settings. // EnableTLS configures each config to use TLS according to default settings.
// It will only change configs that are marked as managed, and assumes that // It will only change configs that are marked as managed, and assumes that
// certificates and keys are already on disk. // certificates and keys are already on disk. If loadCertificates is true,
func EnableTLS(configs []server.Config) { // the certificates will be loaded from disk into the cache for this process
// to use. If false, TLS will still be enabled and configured with default
// settings, but no certificates will be parsed loaded into the cache, and
// the returned error value will always be nil.
func EnableTLS(configs []server.Config, loadCertificates bool) error {
for i := 0; i < len(configs); i++ { for i := 0; i < len(configs); i++ {
if !configs[i].TLS.Managed { if !configs[i].TLS.Managed {
continue continue
} }
configs[i].TLS.Enabled = true configs[i].TLS.Enabled = true
configs[i].TLS.Certificate = storage.SiteCertFile(configs[i].Host) if loadCertificates && HostQualifies(configs[i].Host) {
configs[i].TLS.Key = storage.SiteKeyFile(configs[i].Host) _, err := cacheManagedCertificate(configs[i].Host, false)
setup.SetDefaultTLSParams(&configs[i]) if err != nil {
} return err
}
// StapleOCSP staples OCSP responses to each config according to their certificate.
// This should work for any TLS-enabled config, not just Let's Encrypt ones.
func StapleOCSP(configs []server.Config) error {
for i := 0; i < len(configs); i++ {
if configs[i].TLS.Certificate == "" {
continue
}
bundleBytes, err := ioutil.ReadFile(configs[i].TLS.Certificate)
if err != nil {
return errors.New("load certificate to staple ocsp: " + err.Error())
}
ocspBytes, ocspResp, err := acme.GetOCSPForCert(bundleBytes)
if err == nil {
// TODO: We ignore the error if it exists because some certificates
// may not have an issuer URL which we should ignore anyway, and
// sometimes we get syntax errors in the responses. To reproduce this
// behavior, start Caddy with an empty Caddyfile and -log stderr. Then
// add a host to the Caddyfile which requires a new LE certificate.
// Reload Caddy's config with SIGUSR1, and see the log report that it
// obtains the certificate, but then an error:
// getting ocsp: asn1: syntax error: sequence truncated
// But retrying the reload again sometimes solves the problem. It's flaky...
ocspCache[&bundleBytes] = ocspResp
if ocspResp.Status == ocsp.Good {
configs[i].TLS.OCSPStaple = ocspBytes
} }
} }
setDefaultTLSParams(&configs[i])
} }
return nil return nil
} }
...@@ -266,28 +234,29 @@ func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { ...@@ -266,28 +234,29 @@ func MakePlaintextRedirects(allConfigs []server.Config) []server.Config {
} }
// ConfigQualifies returns true if cfg qualifies for // ConfigQualifies returns true if cfg qualifies for
// fully managed TLS. It does NOT check to see if a // fully managed TLS (but not on-demand TLS, which is
// not considered here). It does NOT check to see if a
// cert and key already exist for the config. If the // cert and key already exist for the config. If the
// config does qualify, you should set cfg.TLS.Managed // config does qualify, you should set cfg.TLS.Managed
// to true and use that instead, because the process of // to true and check that instead, because the process of
// setting up the config may make it look like it // setting up the config may make it look like it
// doesn't qualify even though it originally did. // doesn't qualify even though it originally did.
func ConfigQualifies(cfg server.Config) bool { func ConfigQualifies(cfg server.Config) bool {
return cfg.TLS.Certificate == "" && // user could provide their own cert and key return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
cfg.TLS.Key == "" &&
// user can force-disable automatic HTTPS for this host // user can force-disable automatic HTTPS for this host
cfg.Scheme != "http" && cfg.Scheme != "http" &&
cfg.Port != "80" && cfg.Port != "80" &&
cfg.TLS.LetsEncryptEmail != "off" && cfg.TLS.LetsEncryptEmail != "off" &&
// we get can't certs for some kinds of hostnames // we get can't certs for some kinds of hostnames, but
HostQualifies(cfg.Host) // on-demand TLS allows empty hostnames at startup
(HostQualifies(cfg.Host) || cfg.TLS.OnDemand)
} }
// HostQualifies returns true if the hostname alone // HostQualifies returns true if the hostname alone
// appears eligible for automatic HTTPS. For example, // appears eligible for automatic HTTPS. For example,
// localhost, empty hostname, and wildcard hosts are // localhost, empty hostname, and IP addresses are
// not eligible because we cannot obtain certificates // not eligible because we cannot obtain certificates
// for those names. // for those names.
func HostQualifies(hostname string) bool { func HostQualifies(hostname string) bool {
...@@ -317,71 +286,6 @@ func existingCertAndKey(host string) bool { ...@@ -317,71 +286,6 @@ func existingCertAndKey(host string) bool {
return true return true
} }
// newClient creates a new ACME client to facilitate communication
// with the Let's Encrypt CA server on behalf of the user specified
// by leEmail. As part of this process, a user will be loaded from
// disk (if already exists) or created new and registered via ACME
// and saved to the file system for next time.
func newClient(leEmail string) (*acme.Client, error) {
return newClientPort(leEmail, "")
}
// newClientPort does the same thing as newClient, except it creates a
// new client with a custom port used for ACME transactions instead of
// the default port. This is important if the default port is already in
// use or is not exposed to the public, etc.
func newClientPort(leEmail, port string) (*acme.Client, error) {
// Look up or create the LE user account
leUser, err := getUser(leEmail)
if err != nil {
return nil, err
}
// The client facilitates our communication with the CA server.
client, err := acme.NewClient(CAUrl, &leUser, rsaKeySizeToUse)
if err != nil {
return nil, err
}
if port != "" {
client.SetHTTPAddress(":" + port)
client.SetTLSAddress(":" + port)
}
client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // We can only guarantee http-01 at this time, but tls-01 should work if port is not custom!
// If not registered, the user must register an account with the CA
// and agree to terms
if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if port == "" { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" {
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
}
if !Agreed && reg.TosURL == "" {
return nil, errors.New("user must agree to terms")
}
}
err = client.AgreeToTOS()
if err != nil {
saveUser(leUser) // TODO: Might as well try, right? Error check?
return nil, errors.New("error agreeing to terms: " + err.Error())
}
// save user to the file system
err = saveUser(leUser)
if err != nil {
return nil, errors.New("could not save user: " + err.Error())
}
}
return client, nil
}
// saveCertResource saves the certificate resource to disk. This // saveCertResource saves the certificate resource to disk. This
// includes the certificate file itself, the private key, and the // includes the certificate file itself, the private key, and the
// metadata file. // metadata file.
...@@ -421,7 +325,7 @@ func saveCertResource(cert acme.CertificateResource) error { ...@@ -421,7 +325,7 @@ func saveCertResource(cert acme.CertificateResource) error {
// be the HTTPS configuration. The returned configuration is set // be the HTTPS configuration. The returned configuration is set
// to listen on port 80. // to listen on port 80.
func redirPlaintextHost(cfg server.Config) server.Config { func redirPlaintextHost(cfg server.Config) server.Config {
toURL := "https://" + cfg.Host toURL := "https://{host}" // serve any host, since cfg.Host could be empty
if cfg.Port != "443" && cfg.Port != "80" { if cfg.Port != "443" && cfg.Port != "80" {
toURL += ":" + cfg.Port toURL += ":" + cfg.Port
} }
...@@ -438,12 +342,10 @@ func redirPlaintextHost(cfg server.Config) server.Config { ...@@ -438,12 +342,10 @@ func redirPlaintextHost(cfg server.Config) server.Config {
} }
return server.Config{ return server.Config{
Host: cfg.Host, Host: cfg.Host,
BindHost: cfg.BindHost, BindHost: cfg.BindHost,
Port: "80", Port: "80",
Middleware: map[string][]middleware.Middleware{ Middleware: []middleware.Middleware{redirMidware},
"/": []middleware.Middleware{redirMidware},
},
} }
} }
...@@ -453,12 +355,12 @@ func Revoke(host string) error { ...@@ -453,12 +355,12 @@ func Revoke(host string) error {
return errors.New("no certificate and key for " + host) return errors.New("no certificate and key for " + host)
} }
email := getEmail(server.Config{Host: host}, false) email := getEmail(server.Config{Host: host}, true)
if email == "" { if email == "" {
return errors.New("email is required to revoke") return errors.New("email is required to revoke")
} }
client, err := newClient(email) client, err := NewACMEClient(email, true)
if err != nil { if err != nil {
return err return err
} }
...@@ -493,42 +395,17 @@ var ( ...@@ -493,42 +395,17 @@ var (
CAUrl string CAUrl string
) )
// Some essential values related to the Let's Encrypt process // AlternatePort is the port on which the acme client will open a
const ( // listener and solve the CA's challenges. If this alternate port
// AlternatePort is the port on which the acme client will open a // is used instead of the default port (80 or 443), then the
// listener and solve the CA's challenges. If this alternate port // default port for the challenge must be forwarded to this one.
// is used instead of the default port (80 or 443), then the const AlternatePort = "5033"
// default port for the challenge must be forwarded to this one.
AlternatePort = "5033"
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 24 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// KeySize represents the length of a key in bits.
type KeySize int
// Key sizes are used to determine the strength of a key. // KeyType is the type to use for new keys.
const (
Ecc224 KeySize = 224
Ecc256 = 256
Rsa2048 = 2048
Rsa4096 = 4096
)
// rsaKeySizeToUse is the size to use for new RSA keys.
// This shouldn't need to change except for in tests; // This shouldn't need to change except for in tests;
// the size can be drastically reduced for speed. // the size can be drastically reduced for speed.
var rsaKeySizeToUse = Rsa2048 var KeyType = acme.EC384
// stopChan is used to signal the maintenance goroutine // stopChan is used to signal the maintenance goroutine
// to terminate. // to terminate.
var stopChan chan struct{} var stopChan chan struct{}
// ocspCache maps certificate bundle to OCSP response.
// It is used during regular OCSP checks to see if the OCSP
// response needs to be updated.
var ocspCache = make(map[*[]byte]*ocsp.Response)
package letsencrypt package https
import ( import (
"io/ioutil" "io/ioutil"
...@@ -46,10 +46,11 @@ func TestConfigQualifies(t *testing.T) { ...@@ -46,10 +46,11 @@ func TestConfigQualifies(t *testing.T) {
cfg server.Config cfg server.Config
expect bool expect bool
}{ }{
{server.Config{Host: ""}, false},
{server.Config{Host: "localhost"}, false}, {server.Config{Host: "localhost"}, false},
{server.Config{Host: "123.44.3.21"}, false},
{server.Config{Host: "example.com"}, true}, {server.Config{Host: "example.com"}, true},
{server.Config{Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, false}, {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false},
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true},
{server.Config{Host: "example.com", Scheme: "http"}, false}, {server.Config{Host: "example.com", Scheme: "http"}, false},
...@@ -86,11 +87,11 @@ func TestRedirPlaintextHost(t *testing.T) { ...@@ -86,11 +87,11 @@ func TestRedirPlaintextHost(t *testing.T) {
} }
// Make sure redirect handler is set up properly // Make sure redirect handler is set up properly
if cfg.Middleware == nil || len(cfg.Middleware["/"]) != 1 { if cfg.Middleware == nil || len(cfg.Middleware) != 1 {
t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware) t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware)
} }
handler, ok := cfg.Middleware["/"][0](nil).(redirect.Redirect) handler, ok := cfg.Middleware[0](nil).(redirect.Redirect)
if !ok { if !ok {
t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler) t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler)
} }
...@@ -105,18 +106,18 @@ func TestRedirPlaintextHost(t *testing.T) { ...@@ -105,18 +106,18 @@ func TestRedirPlaintextHost(t *testing.T) {
if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected { if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected {
t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual) t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual)
} }
if actual, expected := handler.Rules[0].To, "https://example.com:1234{uri}"; actual != expected { if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected {
t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
} }
if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected { if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected {
t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual) t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual)
} }
// browsers can interpret default ports with scheme, so make sure the port // browsers can infer a default port from scheme, so make sure the port
// doesn't get added in explicitly for default ports. // doesn't get added in explicitly for default ports like 443 for https.
cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"}) cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"})
handler, ok = cfg.Middleware["/"][0](nil).(redirect.Redirect) handler, ok = cfg.Middleware[0](nil).(redirect.Redirect)
if actual, expected := handler.Rules[0].To, "https://example.com{uri}"; actual != expected { if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected {
t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
} }
} }
...@@ -208,9 +209,9 @@ func TestExistingCertAndKey(t *testing.T) { ...@@ -208,9 +209,9 @@ func TestExistingCertAndKey(t *testing.T) {
func TestHostHasOtherPort(t *testing.T) { func TestHostHasOtherPort(t *testing.T) {
configs := []server.Config{ configs := []server.Config{
server.Config{Host: "example.com", Port: "80"}, {Host: "example.com", Port: "80"},
server.Config{Host: "sub1.example.com", Port: "80"}, {Host: "sub1.example.com", Port: "80"},
server.Config{Host: "sub1.example.com", Port: "443"}, {Host: "sub1.example.com", Port: "443"},
} }
if hostHasOtherPort(configs, 0, "80") { if hostHasOtherPort(configs, 0, "80") {
...@@ -227,18 +228,18 @@ func TestHostHasOtherPort(t *testing.T) { ...@@ -227,18 +228,18 @@ func TestHostHasOtherPort(t *testing.T) {
func TestMakePlaintextRedirects(t *testing.T) { func TestMakePlaintextRedirects(t *testing.T) {
configs := []server.Config{ configs := []server.Config{
// Happy path = standard redirect from 80 to 443 // Happy path = standard redirect from 80 to 443
server.Config{Host: "example.com", TLS: server.TLSConfig{Managed: true}}, {Host: "example.com", TLS: server.TLSConfig{Managed: true}},
// Host on port 80 already defined; don't change it (no redirect) // Host on port 80 already defined; don't change it (no redirect)
server.Config{Host: "sub1.example.com", Port: "80", Scheme: "http"}, {Host: "sub1.example.com", Port: "80", Scheme: "http"},
server.Config{Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}}, {Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}},
// Redirect from port 80 to port 5000 in this case // Redirect from port 80 to port 5000 in this case
server.Config{Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}}, {Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}},
// Can redirect from 80 to either 443 or 5001, but choose 443 // Can redirect from 80 to either 443 or 5001, but choose 443
server.Config{Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}}, {Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}},
server.Config{Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}}, {Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}},
} }
result := MakePlaintextRedirects(configs) result := MakePlaintextRedirects(configs)
...@@ -252,31 +253,18 @@ func TestMakePlaintextRedirects(t *testing.T) { ...@@ -252,31 +253,18 @@ func TestMakePlaintextRedirects(t *testing.T) {
func TestEnableTLS(t *testing.T) { func TestEnableTLS(t *testing.T) {
configs := []server.Config{ configs := []server.Config{
server.Config{TLS: server.TLSConfig{Managed: true}}, {Host: "example.com", TLS: server.TLSConfig{Managed: true}},
server.Config{}, // not managed - no changes! {}, // not managed - no changes!
} }
EnableTLS(configs) EnableTLS(configs, false)
if !configs[0].TLS.Enabled { if !configs[0].TLS.Enabled {
t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
} }
if configs[0].TLS.Certificate == "" {
t.Errorf("Expected config 0 to have TLS.Certificate set, but it was empty")
}
if configs[0].TLS.Key == "" {
t.Errorf("Expected config 0 to have TLS.Key set, but it was empty")
}
if configs[1].TLS.Enabled { if configs[1].TLS.Enabled {
t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
} }
if configs[1].TLS.Certificate != "" {
t.Errorf("Expected config 1 to have TLS.Certificate empty, but it was: %s", configs[1].TLS.Certificate)
}
if configs[1].TLS.Key != "" {
t.Errorf("Expected config 1 to have TLS.Key empty, but it was: %s", configs[1].TLS.Key)
}
} }
func TestGroupConfigsByEmail(t *testing.T) { func TestGroupConfigsByEmail(t *testing.T) {
...@@ -285,12 +273,12 @@ func TestGroupConfigsByEmail(t *testing.T) { ...@@ -285,12 +273,12 @@ func TestGroupConfigsByEmail(t *testing.T) {
} }
configs := []server.Config{ configs := []server.Config{
server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
server.Config{Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, {Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
server.Config{Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, {Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
server.Config{Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, {Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
server.Config{Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, {Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
server.Config{Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed {Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed
} }
DefaultEmail = "test@example.com" DefaultEmail = "test@example.com"
...@@ -314,10 +302,11 @@ func TestGroupConfigsByEmail(t *testing.T) { ...@@ -314,10 +302,11 @@ func TestGroupConfigsByEmail(t *testing.T) {
func TestMarkQualified(t *testing.T) { func TestMarkQualified(t *testing.T) {
// TODO: TestConfigQualifies and this test share the same config list... // TODO: TestConfigQualifies and this test share the same config list...
configs := []server.Config{ configs := []server.Config{
{Host: ""},
{Host: "localhost"}, {Host: "localhost"},
{Host: "123.44.3.21"},
{Host: "example.com"}, {Host: "example.com"},
{Host: "example.com", TLS: server.TLSConfig{Certificate: "cert.pem"}}, {Host: "example.com", TLS: server.TLSConfig{Manual: true}},
{Host: "example.com", TLS: server.TLSConfig{Key: "key.pem"}},
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}},
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}},
{Host: "example.com", Scheme: "http"}, {Host: "example.com", Scheme: "http"},
......
package https
import (
"log"
"time"
"github.com/mholt/caddy/server"
"golang.org/x/crypto/ocsp"
)
const (
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 12 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs
// that are expiring soon. It also updates OCSP stapling and
// performs other maintenance of assets.
//
// You must pass in the channel which you'll close when
// maintenance should stop, to allow this goroutine to clean up
// after itself and unblock.
func maintainAssets(stopChan chan struct{}) {
renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(OCSPInterval)
for {
select {
case <-renewalTicker.C:
log.Println("[INFO] Scanning for expiring certificates")
renewManagedCertificates(false)
log.Println("[INFO] Done checking certificates")
case <-ocspTicker.C:
log.Println("[INFO] Scanning for stale OCSP staples")
updateOCSPStaples()
log.Println("[INFO] Done checking OCSP staples")
case <-stopChan:
renewalTicker.Stop()
ocspTicker.Stop()
log.Println("[INFO] Stopped background maintenance routine")
return
}
}
}
func renewManagedCertificates(allowPrompts bool) (err error) {
var renewed, deleted []Certificate
var client *ACMEClient
visitedNames := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
if !cert.Managed {
continue
}
// the list of names on this cert should never be empty...
if cert.Names == nil || len(cert.Names) == 0 {
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
deleted = append(deleted, cert)
continue
}
// skip names whose certificate we've already renewed
if _, ok := visitedNames[name]; ok {
continue
}
for _, name := range cert.Names {
visitedNames[name] = struct{}{}
}
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
if client == nil {
client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
if err != nil {
return err
}
client.Configure("") // TODO: Bind address of relevant listener, yuck
}
err := client.Renew(cert.Names[0]) // managed certs better have only one name
if err != nil {
if client.AllowPrompts && timeLeft < 0 {
// Certificate renewal failed, the operator is present, and the certificate
// is already expired; we should stop immediately and return the error. Note
// that we used to do this any time a renewal failed at startup. However,
// after discussion in https://github.com/mholt/caddy/issues/642 we decided to
// only stop startup if the certificate is expired. We still log the error
// otherwise.
certCacheMu.RUnlock()
return err
}
log.Printf("[ERROR] %v", err)
if cert.OnDemand {
deleted = append(deleted, cert)
}
} else {
renewed = append(renewed, cert)
}
}
}
certCacheMu.RUnlock()
// Apply changes to the cache
for _, cert := range renewed {
_, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
if err != nil {
if client.AllowPrompts {
return err // operator is present, so report error immediately
}
log.Printf("[ERROR] %v", err)
}
}
for _, cert := range deleted {
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
}
certCacheMu.Unlock()
}
return nil
}
func updateOCSPStaples() {
// Create a temporary place to store updates
// until we release the potentially long-lived
// read lock and use a short-lived write lock.
type ocspUpdate struct {
rawBytes []byte
parsed *ocsp.Response
}
updated := make(map[string]ocspUpdate)
// A single SAN certificate maps to multiple names, so we use this
// set to make sure we don't waste cycles checking OCSP for the same
// certificate multiple times.
visited := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
// skip this certificate if we've already visited it,
// and if not, mark all the names as visited
if _, ok := visited[name]; ok {
continue
}
for _, n := range cert.Names {
visited[n] = struct{}{}
}
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
continue
}
var lastNextUpdate time.Time
if cert.OCSP != nil {
// start checking OCSP staple about halfway through validity period for good measure
lastNextUpdate = cert.OCSP.NextUpdate
refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
// since OCSP is already stapled, we need only check if we're in that "refresh window"
if time.Now().Before(refreshTime) {
continue
}
}
err := stapleOCSP(&cert, nil)
if err != nil {
if cert.OCSP != nil {
// if it was no staple before, that's fine, otherwise we should log the error
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
}
continue
}
// By this point, we've obtained the latest OCSP response.
// If there was no staple before, or if the response is updated, make
// sure we apply the update to all names on the certificate.
if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
for _, n := range cert.Names {
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
}
}
}
certCacheMu.RUnlock()
// This write lock should be brief since we have all the info we need now.
certCacheMu.Lock()
for name, update := range updated {
cert := certCache[name]
cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert
}
certCacheMu.Unlock()
}
// renewDurationBefore is how long before expiration to renew certificates.
const renewDurationBefore = (24 * time.Hour) * 30
package setup package https
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/pem"
"io/ioutil"
"log" "log"
"os"
"path/filepath"
"strconv"
"strings" "strings"
"github.com/mholt/caddy/caddy/setup"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/server" "github.com/mholt/caddy/server"
) )
// TLS sets up the TLS configuration (but does not activate Let's Encrypt; that is handled elsewhere). // Setup sets up the TLS configuration and installs certificates that
func TLS(c *Controller) (middleware.Middleware, error) { // are specified by the user in the config file. All the automatic HTTPS
if c.Scheme == "http" { // stuff comes later outside of this function.
func Setup(c *setup.Controller) (middleware.Middleware, error) {
if c.Port == "80" || c.Scheme == "http" {
c.TLS.Enabled = false c.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address()) log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address())
} else { return nil, nil
c.TLS.Enabled = true
} }
c.TLS.Enabled = true
for c.Next() { for c.Next() {
var certificateFile, keyFile, loadDir, maxCerts string
args := c.RemainingArgs() args := c.RemainingArgs()
switch len(args) { switch len(args) {
case 1: case 1:
c.TLS.LetsEncryptEmail = args[0] c.TLS.LetsEncryptEmail = args[0]
// user can force-disable LE activation this way // user can force-disable managed TLS this way
if c.TLS.LetsEncryptEmail == "off" { if c.TLS.LetsEncryptEmail == "off" {
c.TLS.Enabled = false c.TLS.Enabled = false
return nil, nil
} }
case 2: case 2:
c.TLS.Certificate = args[0] certificateFile = args[0]
c.TLS.Key = args[1] keyFile = args[1]
c.TLS.Manual = true
} }
// Optional block with extra parameters // Optional block with extra parameters
...@@ -66,9 +79,12 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -66,9 +79,12 @@ func TLS(c *Controller) (middleware.Middleware, error) {
if len(c.TLS.ClientCerts) == 0 { if len(c.TLS.ClientCerts) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
// TODO: Allow this? It's a bad idea to allow HTTP. If we do this, make sure invoking tls at all (even manually) also sets up a redirect if possible? case "load":
// case "allow_http": c.Args(&loadDir)
// c.TLS.DisableHTTPRedir = true c.TLS.Manual = true
case "max_certs":
c.Args(&maxCerts)
c.TLS.OnDemand = true
default: default:
return nil, c.Errf("Unknown keyword '%s'", c.Val()) return nil, c.Errf("Unknown keyword '%s'", c.Val())
} }
...@@ -78,24 +94,140 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -78,24 +94,140 @@ func TLS(c *Controller) (middleware.Middleware, error) {
if len(args) == 0 && !hadBlock { if len(args) == 0 && !hadBlock {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
// set certificate limit if on-demand TLS is enabled
if maxCerts != "" {
maxCertsNum, err := strconv.Atoi(maxCerts)
if err != nil || maxCertsNum < 1 {
return nil, c.Err("max_certs must be a positive integer")
}
if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost...
onDemandMaxIssue = int32(maxCertsNum)
}
}
// don't try to load certificates unless we're supposed to
if !c.TLS.Enabled || !c.TLS.Manual {
continue
}
// load a single certificate and key, if specified
if certificateFile != "" && keyFile != "" {
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
if err != nil {
return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile)
}
// load a directory of certificates, if specified
if loadDir != "" {
err := loadCertsInDir(c, loadDir)
if err != nil {
return nil, err
}
}
} }
SetDefaultTLSParams(c.Config) setDefaultTLSParams(c.Config)
return nil, nil return nil, nil
} }
// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions, // loadCertsInDir loads all the certificates/keys in dir, as long as
// the file ends with .pem. This method of loading certificates is
// modeled after haproxy, which expects the certificate and key to
// be bundled into the same file:
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
//
// This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *setup.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
return nil
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
var foundKey bool // use only the first key in the file
bundle, err := ioutil.ReadFile(path)
if err != nil {
return err
}
for {
// Decode next block so we can see what type it is
var derBlock *pem.Block
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil {
break
}
if derBlock.Type == "CERTIFICATE" {
// Re-encode certificate as PEM, appending to certificate chain
pem.Encode(certBuilder, derBlock)
} else if derBlock.Type == "EC PARAMETERS" {
// EC keys generated from openssl can be composed of two blocks:
// parameters and key (parameter block should come first)
if !foundKey {
// Encode parameters
pem.Encode(keyBuilder, derBlock)
// Key must immediately follow
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
}
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
// RSA key
if !foundKey {
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else {
return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
}
}
certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
if len(certPEMBytes) == 0 {
return c.Errf("%s: failed to parse PEM data", path)
}
if len(keyPEMBytes) == 0 {
return c.Errf("%s: no private key block found", path)
}
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil {
return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
}
return nil
})
}
// setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
// and server preferences of a server.Config if they were not previously set // and server preferences of a server.Config if they were not previously set
// (it does not overwrite; only fills in missing values). // (it does not overwrite; only fills in missing values). It will also set the
func SetDefaultTLSParams(c *server.Config) { // port to 443 if not already set, TLS is enabled, TLS is manual, and the host
// If no ciphers provided, use all that Caddy supports for the protocol // does not equal localhost.
func setDefaultTLSParams(c *server.Config) {
// If no ciphers provided, use default list
if len(c.TLS.Ciphers) == 0 { if len(c.TLS.Ciphers) == 0 {
c.TLS.Ciphers = defaultCiphers c.TLS.Ciphers = defaultCiphers
} }
// Not a cipher suite, but still important for mitigating protocol downgrade attacks // Not a cipher suite, but still important for mitigating protocol downgrade attacks
c.TLS.Ciphers = append(c.TLS.Ciphers, tls.TLS_FALLBACK_SCSV) // (prepend since having it at end breaks http2 due to non-h2-approved suites before it)
c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...)
// Set default protocol min and max versions - must balance compatibility and security // Set default protocol min and max versions - must balance compatibility and security
if c.TLS.ProtocolMinVersion == 0 { if c.TLS.ProtocolMinVersion == 0 {
...@@ -110,14 +242,14 @@ func SetDefaultTLSParams(c *server.Config) { ...@@ -110,14 +242,14 @@ func SetDefaultTLSParams(c *server.Config) {
// Default TLS port is 443; only use if port is not manually specified, // Default TLS port is 443; only use if port is not manually specified,
// TLS is enabled, and the host is not localhost // TLS is enabled, and the host is not localhost
if c.Port == "" && c.TLS.Enabled && c.Host != "localhost" { if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" {
c.Port = "443" c.Port = "443"
} }
} }
// Map of supported protocols // Map of supported protocols.
// SSLv3 will be not supported in future release // SSLv3 will be not supported in future release.
// HTTP/2 only supports TLS 1.2 and higher // HTTP/2 only supports TLS 1.2 and higher.
var supportedProtocols = map[string]uint16{ var supportedProtocols = map[string]uint16{
"ssl3.0": tls.VersionSSL30, "ssl3.0": tls.VersionSSL30,
"tls1.0": tls.VersionTLS10, "tls1.0": tls.VersionTLS10,
...@@ -136,6 +268,8 @@ var supportedProtocols = map[string]uint16{ ...@@ -136,6 +268,8 @@ var supportedProtocols = map[string]uint16{
// //
// This map, like any map, is NOT ORDERED. Do not range over this map. // This map, like any map, is NOT ORDERED. Do not range over this map.
var supportedCiphersMap = map[string]uint16{ var supportedCiphersMap = map[string]uint16{
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
...@@ -155,6 +289,8 @@ var supportedCiphersMap = map[string]uint16{ ...@@ -155,6 +289,8 @@ var supportedCiphersMap = map[string]uint16{
// Note that TLS_FALLBACK_SCSV is not in this list since it is always // Note that TLS_FALLBACK_SCSV is not in this list since it is always
// added manually. // added manually.
var supportedCiphers = []uint16{ var supportedCiphers = []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
...@@ -169,6 +305,8 @@ var supportedCiphers = []uint16{ ...@@ -169,6 +305,8 @@ var supportedCiphers = []uint16{
// List of all the ciphers we want to use by default // List of all the ciphers we want to use by default
var defaultCiphers = []uint16{ var defaultCiphers = []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
......
package setup package https
import ( import (
"crypto/tls" "crypto/tls"
"io/ioutil"
"log"
"os"
"testing" "testing"
"github.com/mholt/caddy/caddy/setup"
) )
func TestTLSParseBasic(t *testing.T) { func TestMain(m *testing.M) {
c := NewTestController(`tls cert.pem key.pem`) // Write test certificates to disk before tests, and clean up
// when we're done.
err := ioutil.WriteFile(certFile, testCert, 0644)
if err != nil {
log.Fatal(err)
}
err = ioutil.WriteFile(keyFile, testKey, 0644)
if err != nil {
os.Remove(certFile)
log.Fatal(err)
}
result := m.Run()
os.Remove(certFile)
os.Remove(keyFile)
os.Exit(result)
}
_, err := TLS(c) func TestSetupParseBasic(t *testing.T) {
c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
_, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
// Basic checks // Basic checks
if c.TLS.Certificate != "cert.pem" { if !c.TLS.Manual {
t.Errorf("Expected certificate arg to be 'cert.pem', was '%s'", c.TLS.Certificate) t.Error("Expected TLS Manual=true, but was false")
}
if c.TLS.Key != "key.pem" {
t.Errorf("Expected key arg to be 'key.pem', was '%s'", c.TLS.Key)
} }
if !c.TLS.Enabled { if !c.TLS.Enabled {
t.Error("Expected TLS Enabled=true, but was false") t.Error("Expected TLS Enabled=true, but was false")
...@@ -34,6 +56,9 @@ func TestTLSParseBasic(t *testing.T) { ...@@ -34,6 +56,9 @@ func TestTLSParseBasic(t *testing.T) {
// Cipher checks // Cipher checks
expectedCiphers := []uint16{ expectedCiphers := []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
...@@ -42,7 +67,6 @@ func TestTLSParseBasic(t *testing.T) { ...@@ -42,7 +67,6 @@ func TestTLSParseBasic(t *testing.T) {
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA, tls.TLS_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA, tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_FALLBACK_SCSV,
} }
// Ensure count is correct (plus one for TLS_FALLBACK_SCSV) // Ensure count is correct (plus one for TLS_FALLBACK_SCSV)
...@@ -63,23 +87,23 @@ func TestTLSParseBasic(t *testing.T) { ...@@ -63,23 +87,23 @@ func TestTLSParseBasic(t *testing.T) {
} }
} }
func TestTLSParseIncompleteParams(t *testing.T) { func TestSetupParseIncompleteParams(t *testing.T) {
// Using tls without args is an error because it's unnecessary. // Using tls without args is an error because it's unnecessary.
c := NewTestController(`tls`) c := setup.NewTestController(`tls`)
_, err := TLS(c) _, err := Setup(c)
if err == nil { if err == nil {
t.Error("Expected an error, but didn't get one") t.Error("Expected an error, but didn't get one")
} }
} }
func TestTLSParseWithOptionalParams(t *testing.T) { func TestSetupParseWithOptionalParams(t *testing.T) {
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl3.0 tls1.2 protocols ssl3.0 tls1.2
ciphers RSA-3DES-EDE-CBC-SHA RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -97,13 +121,13 @@ func TestTLSParseWithOptionalParams(t *testing.T) { ...@@ -97,13 +121,13 @@ func TestTLSParseWithOptionalParams(t *testing.T) {
} }
} }
func TestTLSDefaultWithOptionalParams(t *testing.T) { func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls { params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA ciphers RSA-3DES-EDE-CBC-SHA
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -113,7 +137,7 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) { ...@@ -113,7 +137,7 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) {
} }
// TODO: If we allow this... but probably not a good idea. // TODO: If we allow this... but probably not a good idea.
// func TestTLSDisableHTTPRedirect(t *testing.T) { // func TestSetupDisableHTTPRedirect(t *testing.T) {
// c := NewTestController(`tls { // c := NewTestController(`tls {
// allow_http // allow_http
// }`) // }`)
...@@ -126,34 +150,34 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) { ...@@ -126,34 +150,34 @@ func TestTLSDefaultWithOptionalParams(t *testing.T) {
// } // }
// } // }
func TestTLSParseWithWrongOptionalParams(t *testing.T) { func TestSetupParseWithWrongOptionalParams(t *testing.T) {
// Test protocols wrong params // Test protocols wrong params
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls protocols ssl tls
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
} }
// Test ciphers wrong params // Test ciphers wrong params
params = `tls cert.crt cert.key { params = `tls ` + certFile + ` ` + keyFile + ` {
ciphers not-valid-cipher ciphers not-valid-cipher
}` }`
c = NewTestController(params) c = setup.NewTestController(params)
_, err = TLS(c) _, err = Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected errors, but no error returned") t.Errorf("Expected errors, but no error returned")
} }
} }
func TestTLSParseWithClientAuth(t *testing.T) { func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls cert.crt cert.key { params := `tls ` + certFile + ` ` + keyFile + ` {
clients client_ca.crt client2_ca.crt clients client_ca.crt client2_ca.crt
}` }`
c := NewTestController(params) c := setup.NewTestController(params)
_, err := TLS(c) _, err := Setup(c)
if err != nil { if err != nil {
t.Errorf("Expected no errors, got: %v", err) t.Errorf("Expected no errors, got: %v", err)
} }
...@@ -169,12 +193,40 @@ func TestTLSParseWithClientAuth(t *testing.T) { ...@@ -169,12 +193,40 @@ func TestTLSParseWithClientAuth(t *testing.T) {
} }
// Test missing client cert file // Test missing client cert file
params = `tls cert.crt cert.key { params = `tls ` + certFile + ` ` + keyFile + ` {
clients clients
}` }`
c = NewTestController(params) c = setup.NewTestController(params)
_, err = TLS(c) _, err = Setup(c)
if err == nil { if err == nil {
t.Errorf("Expected an error, but no error returned") t.Errorf("Expected an error, but no error returned")
} }
} }
const (
certFile = "test_cert.pem"
keyFile = "test_key.pem"
)
var testCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
-----END CERTIFICATE-----
`)
var testKey = []byte(`-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
-----END EC PRIVATE KEY-----
`)
package letsencrypt package https
import ( import (
"path/filepath" "path/filepath"
......
package letsencrypt package https
import ( import (
"path/filepath" "path/filepath"
......
package letsencrypt package https
import ( import (
"bufio" "bufio"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -20,7 +22,7 @@ import ( ...@@ -20,7 +22,7 @@ import (
type User struct { type User struct {
Email string Email string
Registration *acme.RegistrationResource Registration *acme.RegistrationResource
key *rsa.PrivateKey key crypto.PrivateKey
} }
// GetEmail gets u's email. // GetEmail gets u's email.
...@@ -34,14 +36,14 @@ func (u User) GetRegistration() *acme.RegistrationResource { ...@@ -34,14 +36,14 @@ func (u User) GetRegistration() *acme.RegistrationResource {
} }
// GetPrivateKey gets u's private key. // GetPrivateKey gets u's private key.
func (u User) GetPrivateKey() *rsa.PrivateKey { func (u User) GetPrivateKey() crypto.PrivateKey {
return u.key return u.key
} }
// getUser loads the user with the given email from disk. // getUser loads the user with the given email from disk.
// If the user does not exist, it will create a new one, // If the user does not exist, it will create a new one,
// but it does NOT save new users to the disk or register // but it does NOT save new users to the disk or register
// them via ACME. // them via ACME. It does NOT prompt the user.
func getUser(email string) (User, error) { func getUser(email string) (User, error) {
var user User var user User
...@@ -63,7 +65,7 @@ func getUser(email string) (User, error) { ...@@ -63,7 +65,7 @@ func getUser(email string) (User, error) {
} }
// load their private key // load their private key
user.key, err = loadRSAPrivateKey(storage.UserKeyFile(email)) user.key, err = loadPrivateKey(storage.UserKeyFile(email))
if err != nil { if err != nil {
return user, err return user, err
} }
...@@ -72,7 +74,8 @@ func getUser(email string) (User, error) { ...@@ -72,7 +74,8 @@ func getUser(email string) (User, error) {
} }
// saveUser persists a user's key and account registration // saveUser persists a user's key and account registration
// to the file system. It does NOT register the user via ACME. // to the file system. It does NOT register the user via ACME
// or prompt the user.
func saveUser(user User) error { func saveUser(user User) error {
// make user account folder // make user account folder
err := os.MkdirAll(storage.User(user.Email), 0700) err := os.MkdirAll(storage.User(user.Email), 0700)
...@@ -81,7 +84,7 @@ func saveUser(user User) error { ...@@ -81,7 +84,7 @@ func saveUser(user User) error {
} }
// save private key file // save private key file
err = saveRSAPrivateKey(user.key, storage.UserKeyFile(user.Email)) err = savePrivateKey(user.key, storage.UserKeyFile(user.Email))
if err != nil { if err != nil {
return err return err
} }
...@@ -99,10 +102,10 @@ func saveUser(user User) error { ...@@ -99,10 +102,10 @@ func saveUser(user User) error {
// with a new private key. This function does NOT save the // with a new private key. This function does NOT save the
// user to disk or register it via ACME. If you want to use // user to disk or register it via ACME. If you want to use
// a user account that might already exist, call getUser // a user account that might already exist, call getUser
// instead. // instead. It does NOT prompt the user.
func newUser(email string) (User, error) { func newUser(email string) (User, error) {
user := User{Email: email} user := User{Email: email}
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySizeToUse) privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil { if err != nil {
return user, errors.New("error generating private key: " + err.Error()) return user, errors.New("error generating private key: " + err.Error())
} }
...@@ -114,10 +117,10 @@ func newUser(email string) (User, error) { ...@@ -114,10 +117,10 @@ func newUser(email string) (User, error) {
// address from the user to use for TLS for cfg. If it // address from the user to use for TLS for cfg. If it
// cannot get an email address, it returns empty string. // cannot get an email address, it returns empty string.
// (It will warn the user of the consequences of an // (It will warn the user of the consequences of an
// empty email.) If skipPrompt is true, the user will // empty email.) This function MAY prompt the user for
// NOT be prompted and an empty email will be returned // input. If userPresent is false, the operator will
// instead. // NOT be prompted and an empty email may be returned.
func getEmail(cfg server.Config, skipPrompt bool) string { func getEmail(cfg server.Config, userPresent bool) string {
// First try the tls directive from the Caddyfile // First try the tls directive from the Caddyfile
leEmail := cfg.TLS.LetsEncryptEmail leEmail := cfg.TLS.LetsEncryptEmail
if leEmail == "" { if leEmail == "" {
...@@ -135,11 +138,12 @@ func getEmail(cfg server.Config, skipPrompt bool) string { ...@@ -135,11 +138,12 @@ func getEmail(cfg server.Config, skipPrompt bool) string {
} }
if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
leEmail = dir.Name() leEmail = dir.Name()
DefaultEmail = leEmail // save for next time
} }
} }
} }
} }
if leEmail == "" && !skipPrompt { if leEmail == "" && userPresent {
// Alas, we must bother the user and ask for an email address; // Alas, we must bother the user and ask for an email address;
// if they proceed they also agree to the SA. // if they proceed they also agree to the SA.
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
...@@ -154,10 +158,11 @@ func getEmail(cfg server.Config, skipPrompt bool) string { ...@@ -154,10 +158,11 @@ func getEmail(cfg server.Config, skipPrompt bool) string {
if err != nil { if err != nil {
return "" return ""
} }
leEmail = strings.TrimSpace(leEmail)
DefaultEmail = leEmail DefaultEmail = leEmail
Agreed = true Agreed = true
} }
return strings.TrimSpace(leEmail) return leEmail
} }
// promptUserAgreement prompts the user to agree to the agreement // promptUserAgreement prompts the user to agree to the agreement
......
package letsencrypt package https
import ( import (
"bytes" "bytes"
...@@ -114,7 +114,7 @@ func TestGetUserAlreadyExists(t *testing.T) { ...@@ -114,7 +114,7 @@ func TestGetUserAlreadyExists(t *testing.T) {
} }
// Assert keys are the same // Assert keys are the same
if !rsaPrivateKeysSame(user.key, user2.key) { if !PrivateKeysSame(user.key, user2.key) {
t.Error("Expected private key to be the same after loading, but it wasn't") t.Error("Expected private key to be the same after loading, but it wasn't")
} }
...@@ -140,13 +140,13 @@ func TestGetEmail(t *testing.T) { ...@@ -140,13 +140,13 @@ func TestGetEmail(t *testing.T) {
LetsEncryptEmail: "test1@foo.com", LetsEncryptEmail: "test1@foo.com",
}, },
} }
actual := getEmail(config, false) actual := getEmail(config, true)
if actual != "test1@foo.com" { if actual != "test1@foo.com" {
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual) t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual)
} }
// Test2: Use default email from flag (or user previously typing it) // Test2: Use default email from flag (or user previously typing it)
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != DefaultEmail { if actual != DefaultEmail {
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual) t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual)
} }
...@@ -158,7 +158,7 @@ func TestGetEmail(t *testing.T) { ...@@ -158,7 +158,7 @@ func TestGetEmail(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Could not simulate user input, error: %v", err) t.Fatalf("Could not simulate user input, error: %v", err)
} }
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != "test3@foo.com" { if actual != "test3@foo.com" {
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
} }
...@@ -189,7 +189,7 @@ func TestGetEmail(t *testing.T) { ...@@ -189,7 +189,7 @@ func TestGetEmail(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
} }
} }
actual = getEmail(server.Config{}, false) actual = getEmail(server.Config{}, true)
if actual != "test4-3@foo.com" { if actual != "test4-3@foo.com" {
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
} }
......
package letsencrypt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io/ioutil"
"os"
)
// loadRSAPrivateKey loads a PEM-encoded RSA private key from file.
func loadRSAPrivateKey(file string) (*rsa.PrivateKey, error) {
keyBytes, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
keyBlock, _ := pem.Decode(keyBytes)
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
}
// saveRSAPrivateKey saves a PEM-encoded RSA private key to file.
func saveRSAPrivateKey(key *rsa.PrivateKey, file string) error {
pemKey := pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
keyOut, err := os.Create(file)
if err != nil {
return err
}
keyOut.Chmod(0600)
defer keyOut.Close()
return pem.Encode(keyOut, &pemKey)
}
package letsencrypt
import (
"encoding/json"
"io/ioutil"
"log"
"time"
"github.com/mholt/caddy/server"
"github.com/xenolf/lego/acme"
)
// OnChange is a callback function that will be used to restart
// the application or the part of the application that uses
// the certificates maintained by this package. When at least
// one certificate is renewed or an OCSP status changes, this
// function will be called.
var OnChange func() error
// maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs
// that are expiring soon. It also updates OCSP stapling and
// performs other maintenance of assets.
//
// You must pass in the server configs to maintain and the channel
// which you'll close when maintenance should stop, to allow this
// goroutine to clean up after itself and unblock.
func maintainAssets(configs []server.Config, stopChan chan struct{}) {
renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(OCSPInterval)
for {
select {
case <-renewalTicker.C:
n, errs := renewCertificates(configs, true)
if len(errs) > 0 {
for _, err := range errs {
log.Printf("[ERROR] Certificate renewal: %v", err)
}
}
// even if there was an error, some renewals may have succeeded
if n > 0 && OnChange != nil {
err := OnChange()
if err != nil {
log.Printf("[ERROR] OnChange after cert renewal: %v", err)
}
}
case <-ocspTicker.C:
for bundle, oldResp := range ocspCache {
// start checking OCSP staple about halfway through validity period for good measure
refreshTime := oldResp.ThisUpdate.Add(oldResp.NextUpdate.Sub(oldResp.ThisUpdate) / 2)
// only check for updated OCSP validity window if refreshTime is in the past
if time.Now().After(refreshTime) {
_, newResp, err := acme.GetOCSPForCert(*bundle)
if err != nil {
log.Printf("[ERROR] Checking OCSP for bundle: %v", err)
continue
}
// we're not looking for different status, just a more future expiration
if newResp.NextUpdate != oldResp.NextUpdate {
if OnChange != nil {
log.Printf("[INFO] Updating OCSP stapling to extend validity period to %v", newResp.NextUpdate)
err := OnChange()
if err != nil {
log.Printf("[ERROR] OnChange after OCSP trigger: %v", err)
}
break
}
}
}
}
case <-stopChan:
renewalTicker.Stop()
ocspTicker.Stop()
return
}
}
}
// renewCertificates loops through all configured site and
// looks for certificates to renew. Nothing is mutated
// through this function; all changes happen directly on disk.
// It returns the number of certificates renewed and any errors
// that occurred. It only performs a renewal if necessary.
// If useCustomPort is true, a custom port will be used, and
// whatever is listening at 443 better proxy ACME requests to it.
// Otherwise, the acme package will create its own listener on 443.
func renewCertificates(configs []server.Config, useCustomPort bool) (int, []error) {
log.Printf("[INFO] Checking certificates for %d hosts", len(configs))
var errs []error
var n int
for _, cfg := range configs {
// Host must be TLS-enabled and have existing assets managed by LE
if !cfg.TLS.Enabled || !existingCertAndKey(cfg.Host) {
continue
}
// Read the certificate and get the NotAfter time.
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue // still have to check other certificates
}
expTime, err := acme.GetPEMCertExpiration(certBytes)
if err != nil {
errs = append(errs, err)
continue
}
// The time returned from the certificate is always in UTC.
// So calculate the time left with local time as UTC.
// Directly convert it to days for the following checks.
daysLeft := int(expTime.Sub(time.Now().UTC()).Hours() / 24)
// Renew if getting close to expiration.
if daysLeft <= renewDaysBefore {
log.Printf("[INFO] Certificate for %s has %d days remaining; attempting renewal", cfg.Host, daysLeft)
var client *acme.Client
if useCustomPort {
client, err = newClientPort("", AlternatePort) // email not used for renewal
} else {
client, err = newClient("")
}
if err != nil {
errs = append(errs, err)
continue
}
// Read and set up cert meta, required for renewal
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue
}
privBytes, err := ioutil.ReadFile(storage.SiteKeyFile(cfg.Host))
if err != nil {
errs = append(errs, err)
continue
}
var certMeta acme.CertificateResource
err = json.Unmarshal(metaBytes, &certMeta)
certMeta.Certificate = certBytes
certMeta.PrivateKey = privBytes
// Renew certificate
Renew:
newCertMeta, err := client.RenewCertificate(certMeta, true)
if err != nil {
if _, ok := err.(acme.TOSError); ok {
err := client.AgreeToTOS()
if err != nil {
errs = append(errs, err)
}
goto Renew
}
time.Sleep(10 * time.Second)
newCertMeta, err = client.RenewCertificate(certMeta, true)
if err != nil {
errs = append(errs, err)
continue
}
}
saveCertResource(newCertMeta)
n++
} else if daysLeft <= renewDaysBefore+7 && daysLeft >= renewDaysBefore+6 {
log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when %d days remain\n", cfg.Host, daysLeft, renewDaysBefore)
}
}
return n, errs
}
// renewDaysBefore is how many days before expiration to renew certificates.
const renewDaysBefore = 14
...@@ -311,19 +311,19 @@ func TestParseAll(t *testing.T) { ...@@ -311,19 +311,19 @@ func TestParseAll(t *testing.T) {
}}, }},
{`localhost:1234`, false, [][]address{ {`localhost:1234`, false, [][]address{
[]address{{"localhost:1234", "", "localhost", "1234"}}, {{"localhost:1234", "", "localhost", "1234"}},
}}, }},
{`localhost:1234 { {`localhost:1234 {
} }
localhost:2015 { localhost:2015 {
}`, false, [][]address{ }`, false, [][]address{
[]address{{"localhost:1234", "", "localhost", "1234"}}, {{"localhost:1234", "", "localhost", "1234"}},
[]address{{"localhost:2015", "", "localhost", "2015"}}, {{"localhost:2015", "", "localhost", "2015"}},
}}, }},
{`localhost:1234, http://host2`, false, [][]address{ {`localhost:1234, http://host2`, false, [][]address{
[]address{{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}}, {{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}},
}}, }},
{`localhost:1234, http://host2,`, true, [][]address{}}, {`localhost:1234, http://host2,`, true, [][]address{}},
...@@ -332,15 +332,15 @@ func TestParseAll(t *testing.T) { ...@@ -332,15 +332,15 @@ func TestParseAll(t *testing.T) {
} }
https://host3.com, https://host4.com { https://host3.com, https://host4.com {
}`, false, [][]address{ }`, false, [][]address{
[]address{{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}}, {{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}},
[]address{{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}}, {{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}},
}}, }},
{`import import_glob*.txt`, false, [][]address{ {`import import_glob*.txt`, false, [][]address{
[]address{{"glob0.host0", "", "glob0.host0", ""}}, {{"glob0.host0", "", "glob0.host0", ""}},
[]address{{"glob0.host1", "", "glob0.host1", ""}}, {{"glob0.host1", "", "glob0.host1", ""}},
[]address{{"glob1.host0", "", "glob1.host0", ""}}, {{"glob1.host0", "", "glob1.host0", ""}},
[]address{{"glob2.host0", "", "glob2.host0", ""}}, {{"glob2.host0", "", "glob2.host0", ""}},
}}, }},
} { } {
p := testParser(test.input) p := testParser(test.input)
......
...@@ -8,11 +8,13 @@ import ( ...@@ -8,11 +8,13 @@ import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"os" "os"
"os/exec" "os/exec"
"path" "path"
"sync/atomic"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
) )
func init() { func init() {
...@@ -55,8 +57,9 @@ func Restart(newCaddyfile Input) error { ...@@ -55,8 +57,9 @@ func Restart(newCaddyfile Input) error {
// Prepare our payload to the child process // Prepare our payload to the child process
cdyfileGob := caddyfileGob{ cdyfileGob := caddyfileGob{
ListenerFds: make(map[string]uintptr), ListenerFds: make(map[string]uintptr),
Caddyfile: newCaddyfile, Caddyfile: newCaddyfile,
OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount),
} }
// Prepare a pipe to the fork's stdin so it can get the Caddyfile // Prepare a pipe to the fork's stdin so it can get the Caddyfile
...@@ -133,13 +136,28 @@ func getCertsForNewCaddyfile(newCaddyfile Input) error { ...@@ -133,13 +136,28 @@ func getCertsForNewCaddyfile(newCaddyfile Input) error {
} }
// first mark the configs that are qualified for managed TLS // first mark the configs that are qualified for managed TLS
letsencrypt.MarkQualified(configs) https.MarkQualified(configs)
// we must make sure port is set before we group by bind address // since we group by bind address to obtain certs, we must call
letsencrypt.EnableTLS(configs) // EnableTLS to make sure the port is set properly first
// (can ignore error since we aren't actually using the certs)
https.EnableTLS(configs, false)
// find out if we can let the acme package start its own challenge listener
// on port 80
var proxyACME bool
serversMu.Lock()
for _, s := range servers {
_, port, _ := net.SplitHostPort(s.Addr)
if port == "80" {
proxyACME = true
break
}
}
serversMu.Unlock()
// place certs on the disk // place certs on the disk
err = letsencrypt.ObtainCerts(configs, letsencrypt.AlternatePort) err = https.ObtainCerts(configs, false, proxyACME)
if err != nil { if err != nil {
return errors.New("obtaining certs: " + err.Error()) return errors.New("obtaining certs: " + err.Error())
} }
......
...@@ -118,7 +118,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` ...@@ -118,7 +118,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
} }
if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") { if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") {
t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'", t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'",
i, j, test.password, actualRule.Password) i, j, test.password, actualRule.Password(""))
} }
expectedRes := fmt.Sprintf("%v", expectedRule.Resources) expectedRes := fmt.Sprintf("%v", expectedRule.Resources)
......
...@@ -41,7 +41,7 @@ func TestBrowse(t *testing.T) { ...@@ -41,7 +41,7 @@ func TestBrowse(t *testing.T) {
// test case #2 tests detectaction of custom template // test case #2 tests detectaction of custom template
{"browse . " + tempTemplatePath, []string{"."}, false}, {"browse . " + tempTemplatePath, []string{"."}, false},
// test case #3 tests detection of non-existant template // test case #3 tests detection of non-existent template
{"browse . " + nonExistantDirPath, nil, true}, {"browse . " + nonExistantDirPath, nil, true},
// test case #4 tests detection of duplicate pathscopes // test case #4 tests detection of duplicate pathscopes
......
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"github.com/mholt/caddy/middleware/errors" "github.com/mholt/caddy/middleware/errors"
) )
// Errors configures a new gzip middleware instance. // Errors configures a new errors middleware instance.
func Errors(c *Controller) (middleware.Middleware, error) { func Errors(c *Controller) (middleware.Middleware, error) {
handler, err := errorsParse(c) handler, err := errorsParse(c)
if err != nil { if err != nil {
......
...@@ -14,34 +14,34 @@ func TestRedir(t *testing.T) { ...@@ -14,34 +14,34 @@ func TestRedir(t *testing.T) {
expectedRules []redirect.Rule expectedRules []redirect.Rule
}{ }{
// test case #0 tests the recognition of a valid HTTP status code defined outside of block statement // test case #0 tests the recognition of a valid HTTP status code defined outside of block statement
{"redir 300 {\n/ /foo\n}", false, []redirect.Rule{redirect.Rule{FromPath: "/", To: "/foo", Code: 300}}}, {"redir 300 {\n/ /foo\n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 300}}},
// test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement // test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement
{"redir 9000 {\n/ /foo\n}", true, []redirect.Rule{redirect.Rule{}}}, {"redir 9000 {\n/ /foo\n}", true, []redirect.Rule{{}}},
// test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement // test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement
{"redir 300 {\n/ /foo 9000\n}", true, []redirect.Rule{redirect.Rule{}}}, {"redir 300 {\n/ /foo 9000\n}", true, []redirect.Rule{{}}},
// test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement // test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement
{"redir 9000 {\n/ /foo 300\n}", true, []redirect.Rule{redirect.Rule{}}}, {"redir 9000 {\n/ /foo 300\n}", true, []redirect.Rule{{}}},
// test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently // test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently
{"redir 302 {\n/foo\n}", false, []redirect.Rule{redirect.Rule{FromPath: "/", To: "/foo", Code: 302}}}, {"redir 302 {\n/foo\n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 302}}},
// test case #5 tests the recognition of a TO and From redirection in a block statement // test case #5 tests the recognition of a TO and From redirection in a block statement
{"redir {\n/bar /foo 303\n}", false, []redirect.Rule{redirect.Rule{FromPath: "/bar", To: "/foo", Code: 303}}}, {"redir {\n/bar /foo 303\n}", false, []redirect.Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
// test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently // test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently
{"redir /foo", false, []redirect.Rule{redirect.Rule{FromPath: "/", To: "/foo", Code: 301}}}, {"redir /foo", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 301}}},
// test case #7 tests the recognition of a TO and From redirection in a non-block statement // test case #7 tests the recognition of a TO and From redirection in a non-block statement
{"redir /bar /foo 303", false, []redirect.Rule{redirect.Rule{FromPath: "/bar", To: "/foo", Code: 303}}}, {"redir /bar /foo 303", false, []redirect.Rule{{FromPath: "/bar", To: "/foo", Code: 303}}},
// test case #8 tests the recognition of multiple redirections // test case #8 tests the recognition of multiple redirections
{"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []redirect.Rule{redirect.Rule{FromPath: "/", To: "/foo", Code: 304}, redirect.Rule{FromPath: "/bar", To: "/foobar", Code: 305}}}, {"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []redirect.Rule{{FromPath: "/", To: "/foo", Code: 304}, {FromPath: "/bar", To: "/foobar", Code: 305}}},
// test case #9 tests the detection of duplicate redirections // test case #9 tests the detection of duplicate redirections
{"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []redirect.Rule{redirect.Rule{}}}, {"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []redirect.Rule{{}}},
} { } {
recievedFunc, err := Redir(NewTestController(test.input)) recievedFunc, err := Redir(NewTestController(test.input))
if err != nil && !test.shouldErr { if err != nil && !test.shouldErr {
......
...@@ -80,8 +80,8 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { ...@@ -80,8 +80,8 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
status, _ = strconv.Atoi(c.Val()) status, _ = strconv.Atoi(c.Val())
if status < 400 || status > 499 { if status < 200 || (status > 299 && status < 400) || status > 499 {
return nil, c.Err("status must be 4xx") return nil, c.Err("status must be 2xx or 4xx")
} }
default: default:
return nil, c.ArgErr() return nil, c.ArgErr()
......
...@@ -135,24 +135,45 @@ func TestRewriteParse(t *testing.T) { ...@@ -135,24 +135,45 @@ func TestRewriteParse(t *testing.T) {
to /to to /to
if {path} is a if {path} is a
}`, false, []rewrite.Rule{ }`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{rewrite.If{A: "{path}", Operator: "is", B: "a"}}}, &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}},
}},
{`rewrite {
status 500
}`, true, []rewrite.Rule{
&rewrite.ComplexRule{},
}}, }},
{`rewrite { {`rewrite {
status 400 status 400
}`, false, []rewrite.Rule{ }`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", Regexp: regexp.MustCompile(".*"), Status: 400}, &rewrite.ComplexRule{Base: "/", Status: 400},
}}, }},
{`rewrite { {`rewrite {
to /to to /to
status 400 status 400
}`, false, []rewrite.Rule{ }`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*"), Status: 400}, &rewrite.ComplexRule{Base: "/", To: "/to", Status: 400},
}}, }},
{`rewrite { {`rewrite {
status 399 status 399
}`, true, []rewrite.Rule{ }`, true, []rewrite.Rule{
&rewrite.ComplexRule{}, &rewrite.ComplexRule{},
}}, }},
{`rewrite {
status 200
}`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", Status: 200},
}},
{`rewrite {
to /to
status 200
}`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", To: "/to", Status: 200},
}},
{`rewrite {
status 199
}`, true, []rewrite.Rule{
&rewrite.ComplexRule{},
}},
{`rewrite { {`rewrite {
status 0 status 0
}`, true, []rewrite.Rule{ }`, true, []rewrite.Rule{
......
...@@ -37,7 +37,7 @@ func TestStartup(t *testing.T) { ...@@ -37,7 +37,7 @@ func TestStartup(t *testing.T) {
// test case #1 tests proper functionality of non-blocking commands // test case #1 tests proper functionality of non-blocking commands
{"startup mkdir " + osSenitiveTestDir + " &", false, true}, {"startup mkdir " + osSenitiveTestDir + " &", false, true},
// test case #2 tests handling of non-existant commands // test case #2 tests handling of non-existent commands
{"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true}, {"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true},
} }
......
CHANGES CHANGES
0.8.2 (February 25, 2016)
- On-demand TLS can obtain certificates during handshakes
- Built with Go 1.6
- Process log (-log) is rotated when it gets large
- Managed certificates get renewed 30 days early instead of just 14
- fastcgi: Allow scheme prefix before address
- markdown: Support for definition lists
- proxy: Allow proxy to insecure HTTPS backends
- proxy: Support proxy to unix socket
- rewrite: Status code can be 2xx or 4xx
- templates: New .Markdown action to interpret included file as Markdown
- templates: .Truncate now truncates from end of string when length is negative
- tls: Set hard limit for certificates obtained with on-demand TLS
- tls: Load certificates from directory
- tls: Add SHA384 cipher suites
- Multiple bug fixes and internal changes
0.8.1 (January 12, 2016) 0.8.1 (January 12, 2016)
- Improved OCSP stapling - Improved OCSP stapling
- Better graceful reload when new hosts need certificates from Let's Encrypt - Better graceful reload when new hosts need certificates from Let's Encrypt
...@@ -14,6 +32,7 @@ CHANGES ...@@ -14,6 +32,7 @@ CHANGES
- tls: No longer allow HTTPS over port 80 - tls: No longer allow HTTPS over port 80
- Dozens of bug fixes, improvements, and more tests across the board - Dozens of bug fixes, improvements, and more tests across the board
0.8.0 (December 4, 2015) 0.8.0 (December 4, 2015)
- HTTPS by default via Let's Encrypt (certs & keys are fully managed) - HTTPS by default via Let's Encrypt (certs & keys are fully managed)
- Graceful restarts (on POSIX-compliant systems) - Graceful restarts (on POSIX-compliant systems)
......
CADDY 0.8.1 CADDY 0.8.2
Website Website
https://caddyserver.com https://caddyserver.com
Twitter
@caddyserver @caddyserver
Source Code Source Code
https://github.com/mholt/caddy https://github.com/mholt/caddy
https://github.com/caddyserver
For instructions on using Caddy, please see the user guide on the website. For instructions on using Caddy, please see the user guide on the website.
For a list of what's new in this version, see CHANGES.txt. For a list of what's new in this version, see CHANGES.txt.
Please consider donating to the project if you think it is helpful,
especially if your company is using Caddy. There are also sponsorship
opportunities available!
If you have a question, bug report, or would like to contribute, please open an If you have a question, bug report, or would like to contribute, please open an
issue or submit a pull request on GitHub. Your contributions do not go unnoticed! issue or submit a pull request on GitHub. Your contributions do not go unnoticed!
......
...@@ -13,33 +13,22 @@ import ( ...@@ -13,33 +13,22 @@ import (
"time" "time"
"github.com/mholt/caddy/caddy" "github.com/mholt/caddy/caddy"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/https"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acme"
) "gopkg.in/natefinch/lumberjack.v2"
var (
conf string
cpu string
logfile string
revoke string
version bool
)
const (
appName = "Caddy"
appVersion = "0.8.1"
) )
func init() { func init() {
caddy.TrapSignals() caddy.TrapSignals()
flag.BoolVar(&letsencrypt.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement") setVersion()
flag.StringVar(&letsencrypt.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server") flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement")
flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server")
flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+caddy.DefaultConfigFile+")") flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+caddy.DefaultConfigFile+")")
flag.StringVar(&cpu, "cpu", "100%", "CPU cap") flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
flag.StringVar(&letsencrypt.DefaultEmail, "email", "", "Default Let's Encrypt account email address") flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address")
flag.DurationVar(&caddy.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") flag.DurationVar(&caddy.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
flag.StringVar(&caddy.Host, "host", caddy.DefaultHost, "Default host") flag.StringVar(&caddy.Host, "host", caddy.DefaultHost, "Default host")
flag.BoolVar(&caddy.HTTP2, "http2", true, "HTTP/2 support") // TODO: temporary flag until http2 merged into std lib flag.BoolVar(&caddy.HTTP2, "http2", true, "Use HTTP/2")
flag.StringVar(&logfile, "log", "", "Process log file") flag.StringVar(&logfile, "log", "", "Process log file")
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file") flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
flag.StringVar(&caddy.Port, "port", caddy.DefaultPort, "Default port") flag.StringVar(&caddy.Port, "port", caddy.DefaultPort, "Default port")
...@@ -65,15 +54,16 @@ func main() { ...@@ -65,15 +54,16 @@ func main() {
case "": case "":
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
default: default:
file, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) log.SetOutput(&lumberjack.Logger{
if err != nil { Filename: logfile,
log.Fatalf("Error opening process log file: %v", err) MaxSize: 100,
} MaxAge: 14,
log.SetOutput(file) MaxBackups: 10,
})
} }
if revoke != "" { if revoke != "" {
err := letsencrypt.Revoke(revoke) err := https.Revoke(revoke)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
...@@ -81,7 +71,10 @@ func main() { ...@@ -81,7 +71,10 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if version { if version {
fmt.Printf("%s %s\n", caddy.AppName, caddy.AppVersion) fmt.Printf("%s %s\n", appName, appVersion)
if devBuild && gitShortStat != "" {
fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
}
os.Exit(0) os.Exit(0)
} }
...@@ -197,3 +190,44 @@ func setCPU(cpu string) error { ...@@ -197,3 +190,44 @@ func setCPU(cpu string) error {
runtime.GOMAXPROCS(numCPU) runtime.GOMAXPROCS(numCPU)
return nil return nil
} }
// setVersion figures out the version information based on
// variables set by -ldflags.
func setVersion() {
// A development build is one that's not at a tag or has uncommitted changes
devBuild = gitTag == "" || gitShortStat != ""
// Only set the appVersion if -ldflags was used
if gitNearestTag != "" || gitTag != "" {
if devBuild && gitNearestTag != "" {
appVersion = fmt.Sprintf("%s (+%s %s)",
strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate)
} else if gitTag != "" {
appVersion = strings.TrimPrefix(gitTag, "v")
}
}
}
const appName = "Caddy"
// Flags that control program flow or startup
var (
conf string
cpu string
logfile string
revoke string
version bool
)
// Build information obtained with the help of -ldflags
var (
appVersion = "(untracked dev build)" // inferred at startup
devBuild = true // inferred at startup
buildDate string // date -u
gitTag string // git describe --exact-match HEAD 2> /dev/null
gitNearestTag string // git describe --abbrev=0 --tags HEAD
gitCommit string // git rev-parse HEAD
gitShortStat string // git diff-index --shortstat
gitFilesModified string // git diff-index --name-only HEAD
)
...@@ -42,3 +42,34 @@ func TestSetCPU(t *testing.T) { ...@@ -42,3 +42,34 @@ func TestSetCPU(t *testing.T) {
runtime.GOMAXPROCS(currentCPU) runtime.GOMAXPROCS(currentCPU)
} }
} }
func TestSetVersion(t *testing.T) {
setVersion()
if !devBuild {
t.Error("Expected default to assume development build, but it didn't")
}
if got, want := appVersion, "(untracked dev build)"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
gitTag = "v1.1"
setVersion()
if devBuild {
t.Error("Expected a stable build if gitTag is set with no changes")
}
if got, want := appVersion, "1.1"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
gitTag = ""
gitNearestTag = "v1.0"
gitCommit = "deadbeef"
buildDate = "Fri Feb 26 06:53:17 UTC 2016"
setVersion()
if !devBuild {
t.Error("Expected inferring a dev build when gitTag is empty")
}
if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
}
...@@ -139,7 +139,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` ...@@ -139,7 +139,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61`
if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil { if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil {
t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err) t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err)
} }
t.Logf("%d. username=%q password=%v", i, rule.Username, rule.Password) t.Logf("%d. username=%q", i, rule.Username)
if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") { if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") {
t.Errorf("%d (%s) password does not match.", i, rule.Username) t.Errorf("%d (%s) password does not match.", i, rule.Username)
} }
......
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"strings" "strings"
"text/template" "text/template"
"time" "time"
"github.com/russross/blackfriday"
) )
// This file contains the context and functions available for // This file contains the context and functions available for
...@@ -130,10 +132,16 @@ func (c Context) PathMatches(pattern string) bool { ...@@ -130,10 +132,16 @@ func (c Context) PathMatches(pattern string) bool {
return Path(c.Req.URL.Path).Matches(pattern) return Path(c.Req.URL.Path).Matches(pattern)
} }
// Truncate truncates the input string to the given length. If // Truncate truncates the input string to the given length.
// input is shorter than length, the entire string is returned. // If length is negative, it returns that many characters
// starting from the end of the string. If the absolute value
// of length is greater than len(input), the whole input is
// returned.
func (c Context) Truncate(input string, length int) string { func (c Context) Truncate(input string, length int) string {
if len(input) > length { if length < 0 && len(input)+length > 0 {
return input[len(input)+length:]
}
if length >= 0 && len(input) > length {
return input[:length] return input[:length]
} }
return input return input
...@@ -190,3 +198,17 @@ func (c Context) StripExt(path string) string { ...@@ -190,3 +198,17 @@ func (c Context) StripExt(path string) string {
func (c Context) Replace(input, find, replacement string) string { func (c Context) Replace(input, find, replacement string) string {
return strings.Replace(input, find, replacement, -1) return strings.Replace(input, find, replacement, -1)
} }
// Markdown returns the HTML contents of the markdown contained in filename
// (relative to the site root).
func (c Context) Markdown(filename string) (string, error) {
body, err := c.Include(filename)
if err != nil {
return "", err
}
renderer := blackfriday.HtmlRenderer(0, "", "")
extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH | blackfriday.EXTENSION_DEFINITION_LISTS
markdown := blackfriday.Markdown([]byte(body), renderer, extns)
return string(markdown), nil
}
...@@ -92,6 +92,45 @@ func TestIncludeNotExisting(t *testing.T) { ...@@ -92,6 +92,45 @@ func TestIncludeNotExisting(t *testing.T) {
} }
} }
func TestMarkdown(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
fileContent string
expectedContent string
}{
// Test 0 - test parsing of markdown
{
fileContent: "* str1\n* str2\n",
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, _ := context.Markdown(inputFilename)
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestCookie(t *testing.T) { func TestCookie(t *testing.T) {
tests := []struct { tests := []struct {
...@@ -420,12 +459,36 @@ func TestTruncate(t *testing.T) { ...@@ -420,12 +459,36 @@ func TestTruncate(t *testing.T) {
inputLength: 10, inputLength: 10,
expected: "string", expected: "string",
}, },
// Test 3 - zero length
{
inputString: "string",
inputLength: 0,
expected: "",
},
// Test 4 - negative, smaller length
{
inputString: "string",
inputLength: -5,
expected: "tring",
},
// Test 5 - negative, exact length
{
inputString: "string",
inputLength: -6,
expected: "string",
},
// Test 6 - negative, bigger length
{
inputString: "string",
inputLength: -7,
expected: "string",
},
} }
for i, test := range tests { for i, test := range tests {
actual := context.Truncate(test.inputString, test.inputLength) actual := context.Truncate(test.inputString, test.inputLength)
if actual != test.expected { if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
} }
} }
} }
......
...@@ -34,6 +34,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er ...@@ -34,6 +34,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
if h.Debug { if h.Debug {
// Write error to response instead of to log // Write error to response instead of to log
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(status) w.WriteHeader(status)
fmt.Fprintln(w, errMsg) fmt.Fprintln(w, errMsg)
return 0, err // returning < 400 signals that a response has been written return 0, err // returning < 400 signals that a response has been written
...@@ -124,6 +125,7 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) { ...@@ -124,6 +125,7 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) {
// Write error and stack trace to the response rather than to a log // Write error and stack trace to the response rather than to a log
var stackBuf [4096]byte var stackBuf [4096]byte
stack := stackBuf[:runtime.Stack(stackBuf[:], false)] stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "%s\n\n%s", panicMsg, stack) fmt.Fprintf(w, "%s\n\n%s", panicMsg, stack)
} else { } else {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"testing" "testing"
...@@ -158,7 +159,10 @@ func TestVisibleErrorWithPanic(t *testing.T) { ...@@ -158,7 +159,10 @@ func TestVisibleErrorWithPanic(t *testing.T) {
func genErrorHandler(status int, err error, body string) middleware.Handler { func genErrorHandler(status int, err error, body string) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprint(w, body) if len(body) > 0 {
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
fmt.Fprint(w, body)
}
return status, err return status, err
}) })
} }
...@@ -71,7 +71,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -71,7 +71,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// Connect to FastCGI gateway // Connect to FastCGI gateway
network, address := rule.parseAddress() network, address := rule.parseAddress()
fcgi, err := Dial(network, address) fcgiBackend, err := Dial(network, address)
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
...@@ -80,13 +80,13 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -80,13 +80,13 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
switch r.Method { switch r.Method {
case "HEAD": case "HEAD":
resp, err = fcgi.Head(env) resp, err = fcgiBackend.Head(env)
case "GET": case "GET":
resp, err = fcgi.Get(env) resp, err = fcgiBackend.Get(env)
case "OPTIONS": case "OPTIONS":
resp, err = fcgi.Options(env) resp, err = fcgiBackend.Options(env)
default: default:
resp, err = fcgi.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
} }
if resp.Body != nil { if resp.Body != nil {
...@@ -97,24 +97,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -97,24 +97,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
// Write response header
writeHeader(w, resp) writeHeader(w, resp)
// Write the response body // Write the response body
// TODO: If this has an error, the response will already be
// partly written. We should copy out of resp.Body into a buffer
// first, then write it to the response...
_, err = io.Copy(w, resp.Body) _, err = io.Copy(w, resp.Body)
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
// FastCGI stderr outputs // Log any stderr output from upstream
if fcgi.stderr.Len() != 0 { if fcgiBackend.stderr.Len() != 0 {
// Remove trailing newline, error logger already does this. // Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(fcgi.stderr.String(), "\n")) err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
} }
return resp.StatusCode, err // Normally we would return the status code if it is an error status (>= 400),
// however, upstream FastCGI apps don't know about our contract and have
// probably already written an error page. So we just return 0, indicating
// that the response body is already written. However, we do return any
// error value so it can be logged.
// Note that the proxy middleware works the same way, returning status=0.
return 0, err
} }
} }
...@@ -130,7 +134,7 @@ func (r Rule) parseAddress() (string, string) { ...@@ -130,7 +134,7 @@ func (r Rule) parseAddress() (string, string) {
if strings.HasPrefix(r.Address, "tcp://") { if strings.HasPrefix(r.Address, "tcp://") {
return "tcp", r.Address[len("tcp://"):] return "tcp", r.Address[len("tcp://"):]
} }
// check if address has fastcgi scheme explicity set // check if address has fastcgi scheme explicitly set
if strings.HasPrefix(r.Address, "fastcgi://") { if strings.HasPrefix(r.Address, "fastcgi://") {
return "tcp", r.Address[len("fastcgi://"):] return "tcp", r.Address[len("fastcgi://"):]
} }
...@@ -174,7 +178,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string] ...@@ -174,7 +178,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
// Separate remote IP and port; more lenient than net.SplitHostPort // Separate remote IP and port; more lenient than net.SplitHostPort
var ip, port string var ip, port string
if idx := strings.Index(r.RemoteAddr, ":"); idx > -1 { if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
ip = r.RemoteAddr[:idx] ip = r.RemoteAddr[:idx]
port = r.RemoteAddr[idx+1:] port = r.RemoteAddr[idx+1:]
} else { } else {
......
package fastcgi package fastcgi
import ( import (
"net"
"net/http"
"net/http/fcgi"
"net/http/httptest"
"net/url"
"strconv"
"testing" "testing"
) )
func TestRuleParseAddress(t *testing.T) { func TestServeHTTP(t *testing.T) {
body := "This is some test body content"
bodyLenStr := strconv.Itoa(len(body))
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to create listener for test: %v", err)
}
defer listener.Close()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", bodyLenStr)
w.Write([]byte(body))
}))
handler := Handler{
Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String()}},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Unable to create request: %v", err)
}
w := httptest.NewRecorder()
status, err := handler.ServeHTTP(w, r)
if got, want := status, 0; got != want {
t.Errorf("Expected returned status code to be %d, got %d", want, got)
}
if err != nil {
t.Errorf("Expected nil error, got: %v", err)
}
if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want {
t.Errorf("Expected Content-Length to be '%s', got: '%s'", want, got)
}
if got, want := w.Body.String(), body; got != want {
t.Errorf("Expected response body to be '%s', got: '%s'", want, got)
}
}
func TestRuleParseAddress(t *testing.T) {
getClientTestTable := []struct { getClientTestTable := []struct {
rule *Rule rule *Rule
expectednetwork string expectednetwork string
...@@ -25,7 +70,61 @@ func TestRuleParseAddress(t *testing.T) { ...@@ -25,7 +70,61 @@ func TestRuleParseAddress(t *testing.T) {
if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress { if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress {
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress) t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress)
} }
}
}
func TestBuildEnv(t *testing.T) {
testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) {
var h Handler
env, err := h.buildEnv(r, rule, fpath)
if err != nil {
t.Error("Unexpected error:", err.Error())
}
for k, v := range envExpected {
if env[k] != v {
t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v)
}
}
} }
rule := Rule{}
url, err := url.Parse("http://localhost:2015/fgci_test.php?test=blabla")
if err != nil {
t.Error("Unexpected error:", err.Error())
}
r := http.Request{
Method: "GET",
URL: url,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Host: "localhost:2015",
RemoteAddr: "[2b02:1810:4f2d:9400:70ab:f822:be8a:9093]:51688",
RequestURI: "/fgci_test.php",
}
fpath := "/fgci_test.php"
var envExpected = map[string]string{
"REMOTE_ADDR": "[2b02:1810:4f2d:9400:70ab:f822:be8a:9093]",
"REMOTE_PORT": "51688",
"SERVER_PROTOCOL": "HTTP/1.1",
"QUERY_STRING": "test=blabla",
"REQUEST_METHOD": "GET",
"HTTP_HOST": "localhost:2015",
}
// 1. Test for full canonical IPv6 address
testBuildEnv(&r, rule, fpath, envExpected)
// 2. Test for shorthand notation of IPv6 address
r.RemoteAddr = "[::1]:51688"
envExpected["REMOTE_ADDR"] = "[::1]"
testBuildEnv(&r, rule, fpath, envExpected)
// 3. Test for IPv4 address
r.RemoteAddr = "192.168.0.10:51688"
envExpected["REMOTE_ADDR"] = "192.168.0.10"
testBuildEnv(&r, rule, fpath, envExpected)
} }
...@@ -169,12 +169,11 @@ type FCGIClient struct { ...@@ -169,12 +169,11 @@ type FCGIClient struct {
reqID uint16 reqID uint16
} }
// Dial connects to the fcgi responder at the specified network address. // DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
// See func net.Dial for a description of the network and address parameters. // See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) { func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
var conn net.Conn var conn net.Conn
conn, err = dialer.Dial(network, address)
conn, err = net.Dial(network, address)
if err != nil { if err != nil {
return return
} }
...@@ -188,6 +187,12 @@ func Dial(network, address string) (fcgi *FCGIClient, err error) { ...@@ -188,6 +187,12 @@ func Dial(network, address string) (fcgi *FCGIClient, err error) {
return return
} }
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) {
return DialWithDialer(network, address, net.Dialer{})
}
// Close closes fcgi connnection // Close closes fcgi connnection
func (c *FCGIClient) Close() { func (c *FCGIClient) Close() {
c.rwc.Close() c.rwc.Close()
......
...@@ -39,9 +39,7 @@ const ( ...@@ -39,9 +39,7 @@ const (
ipPort = "127.0.0.1:59000" ipPort = "127.0.0.1:59000"
) )
var ( var globalt *testing.T
t_ *testing.T
)
type FastCGIServer struct{} type FastCGIServer struct{}
...@@ -158,7 +156,7 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[ ...@@ -158,7 +156,7 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if bytes.Index(content, []byte("FAILED")) >= 0 { if bytes.Index(content, []byte("FAILED")) >= 0 {
t_.Error("Server return failed message") globalt.Error("Server return failed message")
} }
return return
...@@ -193,7 +191,7 @@ func generateRandFile(size int) (p string, m string) { ...@@ -193,7 +191,7 @@ func generateRandFile(size int) (p string, m string) {
func DisabledTest(t *testing.T) { func DisabledTest(t *testing.T) {
// TODO: test chunked reader // TODO: test chunked reader
t_ = t globalt = t
rand.Seed(time.Now().UTC().UnixNano()) rand.Seed(time.Now().UTC().UnixNano())
......
...@@ -45,7 +45,7 @@ func TestServeHTTP(t *testing.T) { ...@@ -45,7 +45,7 @@ func TestServeHTTP(t *testing.T) {
expectedStatus int expectedStatus int
expectedBodyContent string expectedBodyContent string
}{ }{
// Test 0 - access withoutt any path // Test 0 - access without any path
{ {
url: "https://foo", url: "https://foo",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
...@@ -78,7 +78,7 @@ func TestServeHTTP(t *testing.T) { ...@@ -78,7 +78,7 @@ func TestServeHTTP(t *testing.T) {
url: "https://foo/dir/", url: "https://foo/dir/",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
}, },
// Test 6 - access folder withtout trailing slash // Test 6 - access folder without trailing slash
{ {
url: "https://foo/dir", url: "https://foo/dir",
expectedStatus: http.StatusMovedPermanently, expectedStatus: http.StatusMovedPermanently,
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
package gzip package gzip
import ( import (
"bufio"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"strings" "strings"
...@@ -130,3 +132,12 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { ...@@ -130,3 +132,12 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
n, err := w.Writer.Write(b) n, err := w.Writer.Write(b)
return n, err return n, err
} }
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, fmt.Errorf("not a Hijacker")
}
...@@ -32,6 +32,18 @@ func TestMarkdown(t *testing.T) { ...@@ -32,6 +32,18 @@ func TestMarkdown(t *testing.T) {
StaticDir: DefaultStaticDir, StaticDir: DefaultStaticDir,
StaticFiles: make(map[string]string), StaticFiles: make(map[string]string),
}, },
{
Renderer: blackfriday.HtmlRenderer(0, "", ""),
PathScope: "/docflags",
Extensions: []string{".md"},
Styles: []string{},
Scripts: []string{},
Templates: map[string]string{
DefaultTemplate: "testdata/docflags/template.txt",
},
StaticDir: DefaultStaticDir,
StaticFiles: make(map[string]string),
},
{ {
Renderer: blackfriday.HtmlRenderer(0, "", ""), Renderer: blackfriday.HtmlRenderer(0, "", ""),
PathScope: "/log", PathScope: "/log",
...@@ -114,6 +126,26 @@ Welcome to A Caddy website! ...@@ -114,6 +126,26 @@ Welcome to A Caddy website!
t.Fatalf("Expected body: %v got: %v", expectedBody, respBody) t.Fatalf("Expected body: %v got: %v", expectedBody, respBody)
} }
req, err = http.NewRequest("GET", "/docflags/test.md", nil)
if err != nil {
t.Fatalf("Could not create HTTP request: %v", err)
}
rec = httptest.NewRecorder()
md.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("Wrong status, expected: %d and got %d", http.StatusOK, rec.Code)
}
respBody = rec.Body.String()
expectedBody = `Doc.var_string hello
Doc.var_bool <no value>
DocFlags.var_string <no value>
DocFlags.var_bool true`
if !equalStrings(respBody, expectedBody) {
t.Fatalf("Expected body: %v got: %v", expectedBody, respBody)
}
req, err = http.NewRequest("GET", "/log/test.md", nil) req, err = http.NewRequest("GET", "/log/test.md", nil)
if err != nil { if err != nil {
t.Fatalf("Could not create HTTP request: %v", err) t.Fatalf("Could not create HTTP request: %v", err)
...@@ -190,6 +222,7 @@ Welcome to title! ...@@ -190,6 +222,7 @@ Welcome to title!
expectedLinks := []string{ expectedLinks := []string{
"/blog/test.md", "/blog/test.md",
"/docflags/test.md",
"/log/test.md", "/log/test.md",
} }
......
...@@ -23,6 +23,9 @@ type Metadata struct { ...@@ -23,6 +23,9 @@ type Metadata struct {
// Variables to be used with Template // Variables to be used with Template
Variables map[string]string Variables map[string]string
// Flags to be used with Template
Flags map[string]bool
} }
// load loads parsed values in parsedMap into Metadata // load loads parsed values in parsedMap into Metadata
...@@ -40,8 +43,11 @@ func (m *Metadata) load(parsedMap map[string]interface{}) { ...@@ -40,8 +43,11 @@ func (m *Metadata) load(parsedMap map[string]interface{}) {
} }
// store everything as a variable // store everything as a variable
for key, val := range parsedMap { for key, val := range parsedMap {
if v, ok := val.(string); ok { switch v := val.(type) {
case string:
m.Variables[key] = v m.Variables[key] = v
case bool:
m.Flags[key] = v
} }
} }
} }
...@@ -219,11 +225,18 @@ func findParser(b []byte) MetadataParser { ...@@ -219,11 +225,18 @@ func findParser(b []byte) MetadataParser {
return nil return nil
} }
func newMetadata() Metadata {
return Metadata{
Variables: make(map[string]string),
Flags: make(map[string]bool),
}
}
// parsers returns all available parsers // parsers returns all available parsers
func parsers() []MetadataParser { func parsers() []MetadataParser {
return []MetadataParser{ return []MetadataParser{
&JSONMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, &JSONMetadataParser{metadata: newMetadata()},
&TOMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, &TOMLMetadataParser{metadata: newMetadata()},
&YAMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, &YAMLMetadataParser{metadata: newMetadata()},
} }
} }
...@@ -18,11 +18,15 @@ var TOML = [5]string{` ...@@ -18,11 +18,15 @@ var TOML = [5]string{`
title = "A title" title = "A title"
template = "default" template = "default"
name = "value" name = "value"
positive = true
negative = false
`, `,
`+++ `+++
title = "A title" title = "A title"
template = "default" template = "default"
name = "value" name = "value"
positive = true
negative = false
+++ +++
Page content Page content
`, `,
...@@ -30,12 +34,16 @@ Page content ...@@ -30,12 +34,16 @@ Page content
title = "A title" title = "A title"
template = "default" template = "default"
name = "value" name = "value"
positive = true
negative = false
`, `,
`title = "A title" template = "default" [variables] name = "value"`, `title = "A title" template = "default" [variables] name = "value"`,
`+++ `+++
title = "A title" title = "A title"
template = "default" template = "default"
name = "value" name = "value"
positive = true
negative = false
+++ +++
`, `,
} }
...@@ -44,11 +52,15 @@ var YAML = [5]string{` ...@@ -44,11 +52,15 @@ var YAML = [5]string{`
title : A title title : A title
template : default template : default
name : value name : value
positive : true
negative : false
`, `,
`--- `---
title : A title title : A title
template : default template : default
name : value name : value
positive : true
negative : false
--- ---
Page content Page content
`, `,
...@@ -57,11 +69,13 @@ title : A title ...@@ -57,11 +69,13 @@ title : A title
template : default template : default
name : value name : value
`, `,
`title : A title template : default variables : name : value`, `title : A title template : default variables : name : value : positive : true : negative : false`,
`--- `---
title : A title title : A title
template : default template : default
name : value name : value
positive : true
negative : false
--- ---
`, `,
} }
...@@ -69,12 +83,16 @@ name : value ...@@ -69,12 +83,16 @@ name : value
var JSON = [5]string{` var JSON = [5]string{`
"title" : "A title", "title" : "A title",
"template" : "default", "template" : "default",
"name" : "value" "name" : "value",
"positive" : true,
"negative" : false
`, `,
`{ `{
"title" : "A title", "title" : "A title",
"template" : "default", "template" : "default",
"name" : "value" "name" : "value",
"positive" : true,
"negative" : false
} }
Page content Page content
`, `,
...@@ -82,19 +100,25 @@ Page content ...@@ -82,19 +100,25 @@ Page content
{ {
"title" : "A title", "title" : "A title",
"template" : "default", "template" : "default",
"name" : "value" "name" : "value",
"positive" : true,
"negative" : false
`, `,
` `
{ {
"title" :: "A title", "title" :: "A title",
"template" : "default", "template" : "default",
"name" : "value" "name" : "value",
"positive" : true,
"negative" : false
} }
`, `,
`{ `{
"title" : "A title", "title" : "A title",
"template" : "default", "template" : "default",
"name" : "value" "name" : "value",
"positive" : true,
"negative" : false
} }
`, `,
} }
...@@ -108,6 +132,10 @@ func TestParsers(t *testing.T) { ...@@ -108,6 +132,10 @@ func TestParsers(t *testing.T) {
"title": "A title", "title": "A title",
"template": "default", "template": "default",
}, },
Flags: map[string]bool{
"positive": true,
"negative": false,
},
} }
compare := func(m Metadata) bool { compare := func(m Metadata) bool {
if m.Title != expected.Title { if m.Title != expected.Title {
...@@ -121,7 +149,14 @@ func TestParsers(t *testing.T) { ...@@ -121,7 +149,14 @@ func TestParsers(t *testing.T) {
return false return false
} }
} }
return len(m.Variables) == len(expected.Variables) for k, v := range m.Flags {
if v != expected.Flags[k] {
return false
}
}
varLenOK := len(m.Variables) == len(expected.Variables)
flagLenOK := len(m.Flags) == len(expected.Flags)
return varLenOK && flagLenOK
} }
data := []struct { data := []struct {
...@@ -129,9 +164,9 @@ func TestParsers(t *testing.T) { ...@@ -129,9 +164,9 @@ func TestParsers(t *testing.T) {
testData [5]string testData [5]string
name string name string
}{ }{
{&JSONMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, JSON, "json"}, {&JSONMetadataParser{metadata: newMetadata()}, JSON, "json"},
{&YAMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, YAML, "yaml"}, {&YAMLMetadataParser{metadata: newMetadata()}, YAML, "yaml"},
{&TOMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, TOML, "toml"}, {&TOMLMetadataParser{metadata: newMetadata()}, TOML, "toml"},
} }
for _, v := range data { for _, v := range data {
...@@ -207,9 +242,9 @@ Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga, ...@@ -207,9 +242,9 @@ Mycket olika byggnader har man i de nordiska rikena: pyramidformiga, kilformiga,
testData string testData string
name string name string
}{ }{
{&JSONMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, JSON, "json"}, {&JSONMetadataParser{metadata: newMetadata()}, JSON, "json"},
{&YAMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, YAML, "yaml"}, {&YAMLMetadataParser{metadata: newMetadata()}, YAML, "yaml"},
{&TOMLMetadataParser{metadata: Metadata{Variables: make(map[string]string)}}, TOML, "toml"}, {&TOMLMetadataParser{metadata: newMetadata()}, TOML, "toml"},
} }
for _, v := range data { for _, v := range data {
// metadata without identifiers // metadata without identifiers
......
...@@ -23,14 +23,15 @@ const ( ...@@ -23,14 +23,15 @@ const (
// Data represents a markdown document. // Data represents a markdown document.
type Data struct { type Data struct {
middleware.Context middleware.Context
Doc map[string]string Doc map[string]string
Links []PageLink DocFlags map[string]bool
Links []PageLink
} }
// Process processes the contents of a page in b. It parses the metadata // Process processes the contents of a page in b. It parses the metadata
// (if any) and uses the template (if found). // (if any) and uses the template (if found).
func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middleware.Context) ([]byte, error) { func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middleware.Context) ([]byte, error) {
var metadata = Metadata{Variables: make(map[string]string)} var metadata = newMetadata()
var markdown []byte var markdown []byte
var err error var err error
...@@ -68,7 +69,7 @@ func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middlewa ...@@ -68,7 +69,7 @@ func (md Markdown) Process(c *Config, requestPath string, b []byte, ctx middlewa
} }
// process markdown // process markdown
extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH extns := blackfriday.EXTENSION_TABLES | blackfriday.EXTENSION_FENCED_CODE | blackfriday.EXTENSION_STRIKETHROUGH | blackfriday.EXTENSION_DEFINITION_LISTS
markdown = blackfriday.Markdown(markdown, c.Renderer, extns) markdown = blackfriday.Markdown(markdown, c.Renderer, extns)
// set it as body for template // set it as body for template
...@@ -100,9 +101,10 @@ func (md Markdown) processTemplate(c *Config, requestPath string, tmpl []byte, m ...@@ -100,9 +101,10 @@ func (md Markdown) processTemplate(c *Config, requestPath string, tmpl []byte, m
return nil, err return nil, err
} }
mdData := Data{ mdData := Data{
Context: ctx, Context: ctx,
Doc: metadata.Variables, Doc: metadata.Variables,
Links: c.Links, DocFlags: metadata.Flags,
Links: c.Links,
} }
c.RLock() c.RLock()
......
Doc.var_string {{.Doc.var_string}}
Doc.var_bool {{.Doc.var_bool}}
DocFlags.var_string {{.DocFlags.var_string}}
DocFlags.var_bool {{.DocFlags.var_bool}}
---
var_string: hello
var_bool: true
---
...@@ -13,30 +13,24 @@ type ( ...@@ -13,30 +13,24 @@ type (
// passed the next Handler in the chain. // passed the next Handler in the chain.
Middleware func(Handler) Handler Middleware func(Handler) Handler
// Handler is like http.Handler except ServeHTTP returns a status code // Handler is like http.Handler except ServeHTTP may return a status
// and an error. The status code is for the client's benefit; the error // code and/or error.
// value is for the server's benefit. The status code will be sent to
// the client while the error value will be logged privately. Sometimes,
// an error status code (4xx or 5xx) may be returned with a nil error
// when there is no reason to log the error on the server.
// //
// If a HandlerFunc returns an error (status >= 400), it should NOT // If ServeHTTP writes to the response body, it should return a status
// write to the response. This philosophy makes middleware.Handler // code of 0. This signals to other handlers above it that the response
// different from http.Handler: error handling should happen at the // body is already written, and that they should not write to it also.
// application layer or in dedicated error-handling middleware only
// rather than with an "every middleware for itself" paradigm.
// //
// The application or error-handling middleware should incorporate logic // If ServeHTTP encounters an error, it should return the error value
// to ensure that the client always gets a proper response according to // so it can be logged by designated error-handling middleware.
// the status code. For security reasons, it should probably not reveal
// the actual error message. (Instead it should be logged, for example.)
// //
// Handlers which do write to the response should return a status value // If writing a response after calling another ServeHTTP method, the
// < 400 as a signal that a response has been written. In other words, // returned status code SHOULD be used when writing the response.
// only error-handling middleware or the application will write to the //
// response for a status code >= 400. When ANY handler writes to the // If handling errors after calling another ServeHTTP method, the
// response, it should return a status code < 400 to signal others to // returned error value SHOULD be logged or handled accordingly.
// NOT write to the response again, which would be erroneous. //
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface { Handler interface {
ServeHTTP(http.ResponseWriter, *http.Request) (int, error) ServeHTTP(http.ResponseWriter, *http.Request) (int, error)
} }
...@@ -102,7 +96,8 @@ func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) { ...@@ -102,7 +96,8 @@ func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat)) w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
} }
// currentTime returns time.Now() everytime it's called. It's used for mocking in tests. // currentTime, as it is defined here, returns time.Now().
// It's defined as a variable for mocking time in tests.
var currentTime = func() time.Time { var currentTime = func() time.Time {
return time.Now() return time.Now()
} }
...@@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second ...@@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second
// ServeHTTP satisfies the middleware.Handler interface. // ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams { for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) { if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
var replacer middleware.Replacer var replacer middleware.Replacer
......
...@@ -3,6 +3,7 @@ package proxy ...@@ -3,6 +3,7 @@ package proxy
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -11,6 +12,8 @@ import ( ...@@ -11,6 +12,8 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os" "os"
"path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { ...@@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
} }
} }
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
t.Fatalf("Unable to get absolute path: %v", err)
}
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
......
...@@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string { ...@@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string {
return a + b return a + b
} }
// Though the relevant directive prefix is just "unix:", url.Parse
// will - assuming the regular URL scheme - add additional slashes
// as if "unix" was a request protocol.
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
}
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the // URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir", // target's path is "/base" and the incoming request was for "/dir",
...@@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string { ...@@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string {
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
targetQuery := target.RawQuery targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
req.URL.Scheme = target.Scheme if target.Scheme == "unix" {
req.URL.Host = target.Host // to make Dial work with unix URL,
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" { if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery req.URL.RawQuery = targetQuery + req.URL.RawQuery
...@@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { ...@@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without) req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
} }
} }
return &ReverseProxy{Director: director} rp := &ReverseProxy{Director: director}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
}
}
return rp
} }
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {
......
...@@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { ...@@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
upstream.Hosts = make([]*UpstreamHost, len(to)) upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to { for i, host := range to {
if !strings.HasPrefix(host, "http") { if !strings.HasPrefix(host, "http") &&
!strings.HasPrefix(host, "unix:") {
host = "http://" + host host = "http://" + host
} }
uh := &UpstreamHost{ uh := &UpstreamHost{
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
// to be written, however, in which case 200 must be assumed. // to be written, however, in which case 200 must be assumed.
// It is best to have the constructor initialize this type // It is best to have the constructor initialize this type
// with that default status code. // with that default status code.
type responseRecorder struct { type ResponseRecorder struct {
http.ResponseWriter http.ResponseWriter
status int status int
size int size int
...@@ -27,8 +27,8 @@ type responseRecorder struct { ...@@ -27,8 +27,8 @@ type responseRecorder struct {
// Because a status is not set unless WriteHeader is called // Because a status is not set unless WriteHeader is called
// explicitly, this constructor initializes with a status code // explicitly, this constructor initializes with a status code
// of 200 to cover the default case. // of 200 to cover the default case.
func NewResponseRecorder(w http.ResponseWriter) *responseRecorder { func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return &responseRecorder{ return &ResponseRecorder{
ResponseWriter: w, ResponseWriter: w,
status: http.StatusOK, status: http.StatusOK,
start: time.Now(), start: time.Now(),
...@@ -37,14 +37,14 @@ func NewResponseRecorder(w http.ResponseWriter) *responseRecorder { ...@@ -37,14 +37,14 @@ func NewResponseRecorder(w http.ResponseWriter) *responseRecorder {
// WriteHeader records the status code and calls the // WriteHeader records the status code and calls the
// underlying ResponseWriter's WriteHeader method. // underlying ResponseWriter's WriteHeader method.
func (r *responseRecorder) WriteHeader(status int) { func (r *ResponseRecorder) WriteHeader(status int) {
r.status = status r.status = status
r.ResponseWriter.WriteHeader(status) r.ResponseWriter.WriteHeader(status)
} }
// Write is a wrapper that records the size of the body // Write is a wrapper that records the size of the body
// that gets written. // that gets written.
func (r *responseRecorder) Write(buf []byte) (int, error) { func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf) n, err := r.ResponseWriter.Write(buf)
if err == nil { if err == nil {
r.size += n r.size += n
...@@ -52,11 +52,21 @@ func (r *responseRecorder) Write(buf []byte) (int, error) { ...@@ -52,11 +52,21 @@ func (r *responseRecorder) Write(buf []byte) (int, error) {
return n, err return n, err
} }
// Hijacker is a wrapper of http.Hijacker underearth if any, // Size is a Getter to size property
// otherwise it just returns an error. func (r *ResponseRecorder) Size() int {
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return r.size
}
// Status is a Getter to status property
func (r *ResponseRecorder) Status() int {
return r.status
}
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := r.ResponseWriter.(http.Hijacker); ok { if hj, ok := r.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack() return hj.Hijack()
} }
return nil, nil, errors.New("I'm not a Hijacker") return nil, nil, errors.New("not a Hijacker")
} }
...@@ -30,7 +30,7 @@ type replacer struct { ...@@ -30,7 +30,7 @@ type replacer struct {
// values into the replacer. rr may be nil if it is not // values into the replacer. rr may be nil if it is not
// available. emptyValue should be the string that is used // available. emptyValue should be the string that is used
// in place of empty string (can still be empty string). // in place of empty string (can still be empty string).
func NewReplacer(r *http.Request, rr *responseRecorder, emptyValue string) Replacer { func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
rep := replacer{ rep := replacer{
replacements: map[string]string{ replacements: map[string]string{
"{method}": r.Method, "{method}": r.Method,
......
...@@ -26,8 +26,8 @@ type Config struct { ...@@ -26,8 +26,8 @@ type Config struct {
// HTTPS configuration // HTTPS configuration
TLS TLSConfig TLS TLSConfig
// Middleware stack; map of path scope to middleware -- TODO: Support path scope? // Middleware stack
Middleware map[string][]middleware.Middleware Middleware []middleware.Middleware
// Startup is a list of functions (or methods) to execute at // Startup is a list of functions (or methods) to execute at
// server startup and restart; these are executed before any // server startup and restart; these are executed before any
...@@ -65,13 +65,11 @@ func (c Config) Address() string { ...@@ -65,13 +65,11 @@ func (c Config) Address() string {
// TLSConfig describes how TLS should be configured and used. // TLSConfig describes how TLS should be configured and used.
type TLSConfig struct { type TLSConfig struct {
Enabled bool Enabled bool // will be set to true if TLS is enabled
Certificate string LetsEncryptEmail string
Key string Manual bool // will be set to true if user provides own certs and keys
LetsEncryptEmail string Managed bool // will be set to true if config qualifies for implicit automatic/managed HTTPS
Managed bool // will be set to true if config qualifies for automatic, managed TLS OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes)
//DisableHTTPRedir bool // TODO: not a good idea - should we really allow it?
OCSPStaple []byte
Ciphers []uint16 Ciphers []uint16
ProtocolMinVersion uint16 ProtocolMinVersion uint16
ProtocolMaxVersion uint16 ProtocolMaxVersion uint16
......
...@@ -15,8 +15,6 @@ import ( ...@@ -15,8 +15,6 @@ import (
"runtime" "runtime"
"sync" "sync"
"time" "time"
"golang.org/x/net/http2"
) )
// Server represents an instance of a server, which serves // Server represents an instance of a server, which serves
...@@ -26,8 +24,9 @@ import ( ...@@ -26,8 +24,9 @@ import (
// graceful termination (POSIX only). // graceful termination (POSIX only).
type Server struct { type Server struct {
*http.Server *http.Server
HTTP2 bool // temporary while http2 is not in std lib (TODO: remove flag when part of std lib) HTTP2 bool // whether to enable HTTP/2
tls bool // whether this server is serving all HTTPS hosts or not tls bool // whether this server is serving all HTTPS hosts or not
OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time)
vhosts map[string]virtualHost // virtual hosts keyed by their address vhosts map[string]virtualHost // virtual hosts keyed by their address
listener ListenerFile // the listener which is bound to the socket listener ListenerFile // the listener which is bound to the socket
listenerMu sync.Mutex // protects listener listenerMu sync.Mutex // protects listener
...@@ -35,6 +34,7 @@ type Server struct { ...@@ -35,6 +34,7 @@ type Server struct {
startChan chan struct{} // used to block until server is finished starting startChan chan struct{} // used to block until server is finished starting
connTimeout time.Duration // the maximum duration of a graceful shutdown connTimeout time.Duration // the maximum duration of a graceful shutdown
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
} }
// ListenerFile represents a listener. // ListenerFile represents a listener.
...@@ -60,20 +60,23 @@ type OptionalCallback func(http.ResponseWriter, *http.Request) bool ...@@ -60,20 +60,23 @@ type OptionalCallback func(http.ResponseWriter, *http.Request) bool
// as it stands, you should dispose of a server after stopping it. // as it stands, you should dispose of a server after stopping it.
// The behavior of serving with a spent server is undefined. // The behavior of serving with a spent server is undefined.
func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) { func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) {
var tls bool var useTLS, useOnDemandTLS bool
if len(configs) > 0 { if len(configs) > 0 {
tls = configs[0].TLS.Enabled useTLS = configs[0].TLS.Enabled
useOnDemandTLS = configs[0].TLS.OnDemand
} }
s := &Server{ s := &Server{
Server: &http.Server{ Server: &http.Server{
Addr: addr, Addr: addr,
TLSConfig: new(tls.Config),
// TODO: Make these values configurable? // TODO: Make these values configurable?
// ReadTimeout: 2 * time.Minute, // ReadTimeout: 2 * time.Minute,
// WriteTimeout: 2 * time.Minute, // WriteTimeout: 2 * time.Minute,
// MaxHeaderBytes: 1 << 16, // MaxHeaderBytes: 1 << 16,
}, },
tls: tls, tls: useTLS,
OnDemandTLS: useOnDemandTLS,
vhosts: make(map[string]virtualHost), vhosts: make(map[string]virtualHost),
startChan: make(chan struct{}), startChan: make(chan struct{}),
connTimeout: gracefulTimeout, connTimeout: gracefulTimeout,
...@@ -168,7 +171,7 @@ func (s *Server) serve(ln ListenerFile) error { ...@@ -168,7 +171,7 @@ func (s *Server) serve(ln ListenerFile) error {
for _, vh := range s.vhosts { for _, vh := range s.vhosts {
tlsConfigs = append(tlsConfigs, vh.config.TLS) tlsConfigs = append(tlsConfigs, vh.config.TLS)
} }
return serveTLSWithSNI(s, s.listener, tlsConfigs) return serveTLS(s, s.listener, tlsConfigs)
} }
close(s.startChan) // unblock anyone waiting for this to start listening close(s.startChan) // unblock anyone waiting for this to start listening
...@@ -179,9 +182,8 @@ func (s *Server) serve(ln ListenerFile) error { ...@@ -179,9 +182,8 @@ func (s *Server) serve(ln ListenerFile) error {
// called just before the listener announces itself on the network // called just before the listener announces itself on the network
// and should only be called when the server is just starting up. // and should only be called when the server is just starting up.
func (s *Server) setup() error { func (s *Server) setup() error {
if s.HTTP2 { if !s.HTTP2 {
// TODO: This call may not be necessary after HTTP/2 is merged into std lib s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
http2.ConfigureServer(s.Server, nil)
} }
// Execute startup functions now // Execute startup functions now
...@@ -197,41 +199,17 @@ func (s *Server) setup() error { ...@@ -197,41 +199,17 @@ func (s *Server) setup() error {
return nil return nil
} }
// serveTLSWithSNI serves TLS with Server Name Indication (SNI) support, which allows // serveTLS serves TLS with SNI and client auth support if s has them enabled. It
// multiple sites (different hostnames) to be served from the same address. It also // blocks until s quits.
// supports client authentication if srv has it enabled. It blocks until s quits. func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
//
// This method is adapted from the std lib's net/http ServeTLS function, which was written
// by the Go Authors. It has been modified to support multiple certificate/key pairs,
// client authentication, and our custom Server type.
func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
config := cloneTLSConfig(s.TLSConfig)
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
// Here we diverge from the stdlib a bit by loading multiple certs/key pairs
// then we map the server names to their certs
var err error
config.Certificates = make([]tls.Certificate, len(tlsConfigs))
for i, tlsConfig := range tlsConfigs {
config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key)
config.Certificates[i].OCSPStaple = tlsConfig.OCSPStaple
if err != nil {
defer close(s.startChan)
return err
}
}
config.BuildNameToCertificate()
// Customize our TLS configuration // Customize our TLS configuration
config.MinVersion = tlsConfigs[0].ProtocolMinVersion s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion
config.MaxVersion = tlsConfigs[0].ProtocolMaxVersion s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion
config.CipherSuites = tlsConfigs[0].Ciphers s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers
config.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
// TLS client authentication, if user enabled it // TLS client authentication, if user enabled it
err = setupClientAuth(tlsConfigs, config) err := setupClientAuth(tlsConfigs, s.TLSConfig)
if err != nil { if err != nil {
defer close(s.startChan) defer close(s.startChan)
return err return err
...@@ -241,7 +219,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { ...@@ -241,7 +219,7 @@ func serveTLSWithSNI(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
// with this TLS listener; tls.listener is unexported and does // with this TLS listener; tls.listener is unexported and does
// not implement the File() method we need for graceful restarts // not implement the File() method we need for graceful restarts
// on POSIX systems. // on POSIX systems.
ln = tls.NewListener(ln, config) ln = tls.NewListener(ln, s.TLSConfig)
close(s.startChan) // unblock anyone waiting for this to start listening close(s.startChan) // unblock anyone waiting for this to start listening
return s.Server.Serve(ln) return s.Server.Serve(ln)
...@@ -345,9 +323,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -345,9 +323,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
DefaultErrorFunc(w, r, status) DefaultErrorFunc(w, r, status)
} }
} else { } else {
// Get the remote host
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteHost = r.RemoteAddr
}
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "No such host at %s", s.Server.Addr) fmt.Fprintf(w, "No such host at %s", s.Server.Addr)
log.Printf("[INFO] %s - No such host at %s", host, s.Server.Addr) log.Printf("[INFO] %s - No such host at %s (Remote: %s, Referer: %s)",
host, s.Server.Addr, remoteHost, r.Header.Get("Referer"))
} }
} }
...@@ -432,34 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) { ...@@ -432,34 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File() return ln.TCPListener.File()
} }
// copied from net/http/transport.go
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}
// ShutdownCallbacks executes all the shutdown callbacks // ShutdownCallbacks executes all the shutdown callbacks
// for all the virtualhosts in servers, and returns all the // for all the virtualhosts in servers, and returns all the
// errors generated during their execution. In other words, // errors generated during their execution. In other words,
......
...@@ -21,13 +21,7 @@ type virtualHost struct { ...@@ -21,13 +21,7 @@ type virtualHost struct {
// ListenAndServe begins. // ListenAndServe begins.
func (vh *virtualHost) buildStack() error { func (vh *virtualHost) buildStack() error {
vh.fileServer = middleware.FileServer(http.Dir(vh.config.Root), []string{vh.config.ConfigFile}) vh.fileServer = middleware.FileServer(http.Dir(vh.config.Root), []string{vh.config.ConfigFile})
vh.compile(vh.config.Middleware)
// TODO: We only compile middleware for the "/" scope.
// Partial support for multiple location contexts already
// exists at the parser and config levels, but until full
// support is implemented, this is all we do right here.
vh.compile(vh.config.Middleware["/"])
return nil return nil
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment