Commit 94532246 authored by Matthew Holt's avatar Matthew Holt

Merge branch 'letsencryptfix'

parents a3f3bc67 fd176597
// Package caddy implements the Caddy web server as a service. // Package caddy implements the Caddy web server as a service
// in your own Go programs.
// //
// To use this package, follow a few simple steps: // To use this package, follow a few simple steps:
// //
...@@ -191,6 +192,7 @@ func startServers(groupings bindingGroup) error { ...@@ -191,6 +192,7 @@ func startServers(groupings bindingGroup) error {
return err return err
} }
s.HTTP2 = HTTP2 // TODO: This setting is temporary s.HTTP2 = HTTP2 // TODO: This setting is temporary
s.ReqCallback = letsencrypt.RequestCallback // ensures we can solve ACME challenges while running
var ln server.ListenerFile var ln server.ListenerFile
if IsRestart() { if IsRestart() {
......
...@@ -28,7 +28,7 @@ func ToJSON(caddyfile []byte) ([]byte, error) { ...@@ -28,7 +28,7 @@ func ToJSON(caddyfile []byte) ([]byte, error) {
// Fill up host list // Fill up host list
for _, host := range sb.HostList() { for _, host := range sb.HostList() {
block.Hosts = append(block.Hosts, strings.TrimSuffix(host, ":")) block.Hosts = append(block.Hosts, standardizeScheme(host))
} }
// Extract directives deterministically by sorting them // Extract directives deterministically by sorting them
...@@ -62,7 +62,6 @@ func ToJSON(caddyfile []byte) ([]byte, error) { ...@@ -62,7 +62,6 @@ func ToJSON(caddyfile []byte) ([]byte, error) {
// but only one line at a time, to be used at the top-level of // but only one line at a time, to be used at the top-level of
// a server block only (where the first token on each line is a // a server block only (where the first token on each line is a
// directive) - not to be used at any other nesting level. // directive) - not to be used at any other nesting level.
// goes to end of line
func constructLine(d *parse.Dispenser) []interface{} { func constructLine(d *parse.Dispenser) []interface{} {
var args []interface{} var args []interface{}
...@@ -80,8 +79,8 @@ func constructLine(d *parse.Dispenser) []interface{} { ...@@ -80,8 +79,8 @@ func constructLine(d *parse.Dispenser) []interface{} {
} }
// constructBlock recursively processes tokens into a // constructBlock recursively processes tokens into a
// JSON-encodable structure. // JSON-encodable structure. To be used in a directive's
// goes to end of block // block. Goes to end of block.
func constructBlock(d *parse.Dispenser) [][]interface{} { func constructBlock(d *parse.Dispenser) [][]interface{} {
block := [][]interface{}{} block := [][]interface{}{}
...@@ -110,15 +109,10 @@ func FromJSON(jsonBytes []byte) ([]byte, error) { ...@@ -110,15 +109,10 @@ func FromJSON(jsonBytes []byte) ([]byte, error) {
result += "\n\n" result += "\n\n"
} }
for i, host := range sb.Hosts { for i, host := range sb.Hosts {
if hostname, port, err := net.SplitHostPort(host); err == nil {
if port == "http" || port == "https" {
host = port + "://" + hostname
}
}
if i > 0 { if i > 0 {
result += ", " result += ", "
} }
result += strings.TrimSuffix(host, ":") result += standardizeScheme(host)
} }
result += jsonToText(sb.Body, 1) result += jsonToText(sb.Body, 1)
} }
...@@ -170,6 +164,17 @@ func jsonToText(scope interface{}, depth int) string { ...@@ -170,6 +164,17 @@ func jsonToText(scope interface{}, depth int) string {
return result return result
} }
// standardizeScheme turns an address like host:https into https://host,
// or "host:" into "host".
func standardizeScheme(addr string) string {
if hostname, port, err := net.SplitHostPort(addr); err == nil {
if port == "http" || port == "https" {
addr = port + "://" + hostname
}
}
return strings.TrimSuffix(addr, ":")
}
// Caddyfile encapsulates a slice of ServerBlocks. // Caddyfile encapsulates a slice of ServerBlocks.
type Caddyfile []ServerBlock type Caddyfile []ServerBlock
......
...@@ -63,7 +63,7 @@ baz" ...@@ -63,7 +63,7 @@ baz"
{ // 8 { // 8
caddyfile: `http://host, https://host { caddyfile: `http://host, https://host {
}`, }`,
json: `[{"hosts":["host:http","host:https"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency json: `[{"hosts":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency
}, },
{ // 9 { // 9
caddyfile: `host { caddyfile: `host {
...@@ -124,3 +124,38 @@ func TestFromJSON(t *testing.T) { ...@@ -124,3 +124,38 @@ func TestFromJSON(t *testing.T) {
} }
} }
} }
func TestStandardizeAddress(t *testing.T) {
// host:https should be converted to https://host
output, err := ToJSON([]byte(`host:https`))
if err != nil {
t.Fatal(err)
}
if expected, actual := `[{"hosts":["https://host"],"body":[]}]`, string(output); expected != actual {
t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
}
output, err = FromJSON([]byte(`[{"hosts":["https://host"],"body":[]}]`))
if err != nil {
t.Fatal(err)
}
if expected, actual := "https://host {\n}", string(output); expected != actual {
t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
}
// host: should be converted to just host
output, err = ToJSON([]byte(`host:`))
if err != nil {
t.Fatal(err)
}
if expected, actual := `[{"hosts":["host"],"body":[]}]`, string(output); expected != actual {
t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
}
output, err = FromJSON([]byte(`[{"hosts":["host:"],"body":[]}]`))
if err != nil {
t.Fatal(err)
}
if expected, actual := "host {\n}", string(output); expected != actual {
t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
}
}
...@@ -21,25 +21,22 @@ const ( ...@@ -21,25 +21,22 @@ const (
DefaultConfigFile = "Caddyfile" DefaultConfigFile = "Caddyfile"
) )
// loadConfigs reads input (named filename) and parses it, returning the // loadConfigsUpToIncludingTLS loads the configs from input with name filename and returns them,
// server configurations in the order they appeared in the input. As part // the parsed server blocks, the index of the last directive it processed, and an error (if any).
// of this, it activates Let's Encrypt for the configs that are produced. func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) {
// Thus, the returned configs are already optimally configured optimally
// for HTTPS.
func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
var configs []server.Config var configs []server.Config
// Each server block represents similar hosts/addresses, since they // Each server block represents similar hosts/addresses, since they
// were grouped together in the Caddyfile. // were grouped together in the Caddyfile.
serverBlocks, err := parse.ServerBlocks(filename, input, true) serverBlocks, err := parse.ServerBlocks(filename, input, true)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
if len(serverBlocks) == 0 { if len(serverBlocks) == 0 {
newInput := DefaultInput() newInput := DefaultInput()
serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true) serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
} }
...@@ -56,6 +53,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -56,6 +53,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
config := server.Config{ config := server.Config{
Host: addr.Host, Host: addr.Host,
Port: addr.Port, Port: addr.Port,
Scheme: addr.Scheme,
Root: Root, Root: Root,
Middleware: make(map[string][]middleware.Middleware), Middleware: make(map[string][]middleware.Middleware),
ConfigFile: filename, ConfigFile: filename,
...@@ -88,7 +86,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -88,7 +86,7 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
// execute setup function and append middleware handler, if any // execute setup function and append middleware handler, if any
midware, err := dir.setup(controller) midware, err := dir.setup(controller)
if err != nil { if err != nil {
return nil, err return nil, nil, lastDirectiveIndex, err
} }
if midware != nil { if midware != nil {
// TODO: For now, we only support the default path scope / // TODO: For now, we only support the default path scope /
...@@ -109,22 +107,31 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { ...@@ -109,22 +107,31 @@ func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
} }
} }
return configs, serverBlocks, lastDirectiveIndex, nil
}
// loadConfigs reads input (named filename) and parses it, returning the
// server configurations in the order they appeared in the input. As part
// of this, it activates Let's Encrypt for the configs that are produced.
// Thus, the returned configs are already optimally configured for HTTPS.
func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input)
if err != nil {
return nil, err
}
// Now we have all the configs, but they have only been set up to the // Now we have all the configs, but they have only been set up to the
// point of tls. We need to activate Let's Encrypt before setting up // point of tls. We need to activate Let's Encrypt before setting up
// the rest of the middlewares so they have correct information regarding // the rest of the middlewares so they have correct information regarding
// TLS configuration, if necessary. (this call is append-only, so our // TLS configuration, if necessary. (this only appends, so our iterations
// iterations below shouldn't be affected) // over server blocks below shouldn't be affected)
if !IsRestart() && !Quiet { if !IsRestart() && !Quiet {
fmt.Print("Activating privacy features...") fmt.Print("Activating privacy features...")
} }
configs, err = letsencrypt.Activate(configs) configs, err = letsencrypt.Activate(configs)
if err != nil { if err != nil {
if !Quiet {
fmt.Println()
}
return nil, err return nil, err
} } else if !IsRestart() && !Quiet {
if !IsRestart() && !Quiet {
fmt.Println(" done.") fmt.Println(" done.")
} }
...@@ -277,46 +284,19 @@ func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) { ...@@ -277,46 +284,19 @@ func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) {
// but execution may continue. The second error, if not nil, is a real // but execution may continue. The second error, if not nil, is a real
// problem and the server should not be started. // problem and the server should not be started.
// //
// This function handles edge cases gracefully. If a port name like // This function does not handle edge cases like port "http" or "https" if
// "http" or "https" is unknown to the system, this function will // they are not known to the system. It does, however, serve on the wildcard
// change them to 80 or 443 respectively. If a hostname fails to // host if resolving the address of the specific hostname fails.
// resolve, that host can still be served but will be listening on
// the wildcard host instead. This function takes care of this for you.
func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) { func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) {
bindHost := conf.BindHost resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port))
// TODO: Do we even need the port? Maybe we just need to look up the host.
resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(bindHost, conf.Port))
if warnErr != nil { if warnErr != nil {
// Most likely the host lookup failed or the port is unknown
tryPort := conf.Port
switch errVal := warnErr.(type) {
case *net.AddrError:
if errVal.Err == "unknown port" {
// some odd Linux machines don't support these port names; see issue #136
switch conf.Port {
case "http":
tryPort = "80"
case "https":
tryPort = "443"
}
}
resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(bindHost, tryPort))
if fatalErr != nil {
return
}
default:
// the hostname probably couldn't be resolved, just bind to wildcard then // the hostname probably couldn't be resolved, just bind to wildcard then
resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("0.0.0.0", tryPort)) resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port))
if fatalErr != nil { if fatalErr != nil {
return return
} }
} }
return
}
return return
} }
...@@ -334,12 +314,12 @@ func validDirective(d string) bool { ...@@ -334,12 +314,12 @@ func validDirective(d string) bool {
// DefaultInput returns the default Caddyfile input // DefaultInput returns the default Caddyfile input
// to use when it is otherwise empty or missing. // to use when it is otherwise empty or missing.
// It uses the default host and port (depends on // It uses the default host and port (depends on
// host, e.g. localhost is 2015, otherwise https) and // host, e.g. localhost is 2015, otherwise 443) and
// root. // root.
func DefaultInput() CaddyfileInput { func DefaultInput() CaddyfileInput {
port := Port port := Port
if letsencrypt.HostQualifies(Host) { if letsencrypt.HostQualifies(Host) && port == DefaultPort {
port = "https" port = "443"
} }
return CaddyfileInput{ return CaddyfileInput{
Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)), Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)),
......
...@@ -13,10 +13,10 @@ func TestDefaultInput(t *testing.T) { ...@@ -13,10 +13,10 @@ func TestDefaultInput(t *testing.T) {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
} }
// next few tests simulate user providing -host flag // next few tests simulate user providing -host and/or -port flags
Host = "not-localhost.com" Host = "not-localhost.com"
if actual, expected := string(DefaultInput().Body()), "not-localhost.com:https\nroot ."; actual != expected { if actual, expected := string(DefaultInput().Body()), "not-localhost.com:443\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
} }
...@@ -29,6 +29,18 @@ func TestDefaultInput(t *testing.T) { ...@@ -29,6 +29,18 @@ func TestDefaultInput(t *testing.T) {
if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected { if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
} }
Host = "not-localhost.com"
Port = "1234"
if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
Host = DefaultHost
Port = "1234"
if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
} }
func TestResolveAddr(t *testing.T) { func TestResolveAddr(t *testing.T) {
...@@ -51,14 +63,14 @@ func TestResolveAddr(t *testing.T) { ...@@ -51,14 +63,14 @@ func TestResolveAddr(t *testing.T) {
{server.Config{Host: "localhost", Port: "80"}, false, false, "<nil>", 80}, {server.Config{Host: "localhost", Port: "80"}, false, false, "<nil>", 80},
{server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "0.0.0.0", 1234}, {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "<nil>", 1234},
{server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80}, {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80},
{server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443}, {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443},
{server.Config{BindHost: "", Port: "1234"}, false, false, "<nil>", 1234}, {server.Config{BindHost: "", Port: "1234"}, false, false, "<nil>", 1234},
{server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0}, {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0},
{server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "0.0.0.0", 1234}, {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "<nil>", 1234},
} { } {
actualAddr, warnErr, fatalErr := resolveAddr(test.config) actualAddr, warnErr, fatalErr := resolveAddr(test.config)
......
...@@ -40,12 +40,12 @@ func TestSaveAndLoadRSAPrivateKey(t *testing.T) { ...@@ -40,12 +40,12 @@ func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
} }
} }
// rsaPrivateKeyBytes returns the bytes of DER-encoded key.
func rsaPrivateKeyBytes(key *rsa.PrivateKey) []byte {
return x509.MarshalPKCS1PrivateKey(key)
}
// rsaPrivateKeysSame compares the bytes of a and b and returns true if they are the same. // rsaPrivateKeysSame compares the bytes of a and b and returns true if they are the same.
func rsaPrivateKeysSame(a, b *rsa.PrivateKey) bool { func rsaPrivateKeysSame(a, b *rsa.PrivateKey) bool {
return bytes.Equal(rsaPrivateKeyBytes(a), rsaPrivateKeyBytes(b)) return bytes.Equal(rsaPrivateKeyBytes(a), rsaPrivateKeyBytes(b))
} }
// rsaPrivateKeyBytes returns the bytes of DER-encoded key.
func rsaPrivateKeyBytes(key *rsa.PrivateKey) []byte {
return x509.MarshalPKCS1PrivateKey(key)
}
...@@ -2,30 +2,21 @@ package letsencrypt ...@@ -2,30 +2,21 @@ package letsencrypt
import ( import (
"crypto/tls" "crypto/tls"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"github.com/mholt/caddy/middleware"
) )
const challengeBasePath = "/.well-known/acme-challenge" const challengeBasePath = "/.well-known/acme-challenge"
// Handler is a Caddy middleware that can proxy ACME challenge // RequestCallback proxies challenge requests to ACME client if the
// requests to the real ACME client endpoint. This is necessary // request path starts with challengeBasePath. It returns true if it
// to renew certificates while the server is running. // handled the request and no more needs to be done; it returns false
type Handler struct { // if this call was a no-op and the request still needs handling.
Next middleware.Handler func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
//ChallengeActive int32 // (TODO) use sync/atomic to set/get this flag safely and efficiently
}
// ServeHTTP is basically a no-op unless an ACME challenge is active on this host
// and the request path matches the expected path exactly.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Proxy challenge requests to ACME client
// TODO: Only do this if a challenge is active?
if strings.HasPrefix(r.URL.Path, challengeBasePath) { if strings.HasPrefix(r.URL.Path, challengeBasePath) {
scheme := "http" scheme := "http"
if r.TLS != nil { if r.TLS != nil {
...@@ -37,9 +28,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -37,9 +28,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
hostname = r.URL.Host hostname = r.URL.Host
} }
upstream, err := url.Parse(scheme + "://" + hostname + ":" + alternatePort) upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort)
if err != nil { if err != nil {
return http.StatusInternalServerError, err w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] letsencrypt handler: %v", err)
return true
} }
proxy := httputil.NewSingleHostReverseProxy(upstream) proxy := httputil.NewSingleHostReverseProxy(upstream)
...@@ -48,8 +41,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -48,8 +41,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
return 0, nil return true
} }
return h.Next.ServeHTTP(w, r) return false
} }
package letsencrypt
import (
"net"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestCallbackNoOp(t *testing.T) {
// try base paths that aren't handled by this handler
for _, url := range []string{
"http://localhost/",
"http://localhost/foo.html",
"http://localhost/.git",
"http://localhost/.well-known/",
"http://localhost/.well-known/acme-challenging",
} {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
if RequestCallback(rw, req) {
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
}
}
}
func TestRequestCallbackSuccess(t *testing.T) {
expectedPath := challengeBasePath + "/asdf"
// Set up fake acme handler backend to make sure proxying succeeds
var proxySuccess bool
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxySuccess = true
if r.URL.Path != expectedPath {
t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
}
}))
// Custom listener that uses the port we expect
ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort)
if err != nil {
t.Fatalf("Unable to start test server listener: %v", err)
}
ts.Listener = ln
// Start our engines and run the test
ts.Start()
defer ts.Close()
req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
RequestCallback(rw, req)
if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't")
}
}
This diff is collapsed.
This diff is collapsed.
...@@ -27,8 +27,8 @@ var OnChange func() error ...@@ -27,8 +27,8 @@ var OnChange func() error
// which you'll close when maintenance should stop, to allow this // which you'll close when maintenance should stop, to allow this
// goroutine to clean up after itself and unblock. // goroutine to clean up after itself and unblock.
func maintainAssets(configs []server.Config, stopChan chan struct{}) { func maintainAssets(configs []server.Config, stopChan chan struct{}) {
renewalTicker := time.NewTicker(renewInterval) renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(ocspInterval) ocspTicker := time.NewTicker(OCSPInterval)
for { for {
select { select {
...@@ -47,17 +47,31 @@ func maintainAssets(configs []server.Config, stopChan chan struct{}) { ...@@ -47,17 +47,31 @@ func maintainAssets(configs []server.Config, stopChan chan struct{}) {
} }
} }
case <-ocspTicker.C: case <-ocspTicker.C:
for bundle, oldStatus := range ocspStatus { for bundle, oldResp := range ocspCache {
_, newStatus, err := acme.GetOCSPForCert(*bundle) // start checking OCSP staple about halfway through validity period for good measure
if err == nil && newStatus != oldStatus && OnChange != nil { refreshTime := oldResp.ThisUpdate.Add(oldResp.NextUpdate.Sub(oldResp.ThisUpdate) / 2)
log.Printf("[INFO] OCSP status changed from %v to %v", oldStatus, newStatus)
// only check for updated OCSP validity window if refreshTime is in the past
if time.Now().After(refreshTime) {
_, newResp, err := acme.GetOCSPForCert(*bundle)
if err != nil {
log.Printf("[ERROR] Checking OCSP for bundle: %v", err)
continue
}
// we're not looking for different status, just a more future expiration
if newResp.NextUpdate != oldResp.NextUpdate {
if OnChange != nil {
log.Printf("[INFO] Updating OCSP stapling to extend validity period to %v", newResp.NextUpdate)
err := OnChange() err := OnChange()
if err != nil { if err != nil {
log.Printf("[ERROR] OnChange after OCSP update: %v", err) log.Printf("[ERROR] OnChange after OCSP trigger: %v", err)
} }
break break
} }
} }
}
}
case <-stopChan: case <-stopChan:
renewalTicker.Stop() renewalTicker.Stop()
ocspTicker.Stop() ocspTicker.Stop()
...@@ -102,12 +116,12 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro ...@@ -102,12 +116,12 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro
// Directly convert it to days for the following checks. // Directly convert it to days for the following checks.
daysLeft := int(expTime.Sub(time.Now().UTC()).Hours() / 24) daysLeft := int(expTime.Sub(time.Now().UTC()).Hours() / 24)
// Renew with two weeks or less remaining. // Renew if getting close to expiration.
if daysLeft <= 14 { if daysLeft <= renewDaysBefore {
log.Printf("[INFO] Certificate for %s has %d days remaining; attempting renewal", cfg.Host, daysLeft) log.Printf("[INFO] Certificate for %s has %d days remaining; attempting renewal", cfg.Host, daysLeft)
var client *acme.Client var client *acme.Client
if useCustomPort { if useCustomPort {
client, err = newClientPort("", alternatePort) // email not used for renewal client, err = newClientPort("", AlternatePort) // email not used for renewal
} else { } else {
client, err = newClient("") client, err = newClient("")
} }
...@@ -134,7 +148,7 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro ...@@ -134,7 +148,7 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro
// Renew certificate // Renew certificate
Renew: Renew:
newCertMeta, err := client.RenewCertificate(certMeta, true, true) newCertMeta, err := client.RenewCertificate(certMeta, true)
if err != nil { if err != nil {
if _, ok := err.(acme.TOSError); ok { if _, ok := err.(acme.TOSError); ok {
err := client.AgreeToTOS() err := client.AgreeToTOS()
...@@ -145,24 +159,22 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro ...@@ -145,24 +159,22 @@ func renewCertificates(configs []server.Config, useCustomPort bool) (int, []erro
} }
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
newCertMeta, err = client.RenewCertificate(certMeta, true, true) newCertMeta, err = client.RenewCertificate(certMeta, true)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
continue continue
} }
} }
saveCertsAndKeys([]acme.CertificateResource{newCertMeta}) saveCertResource(newCertMeta)
n++ n++
} else if daysLeft <= 30 { } else if daysLeft <= renewDaysBefore+7 && daysLeft >= renewDaysBefore+6 {
// Warn on 30 days remaining. TODO: Just do this once... log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when %d days remain\n", cfg.Host, daysLeft, renewDaysBefore)
log.Printf("[WARNING] Certificate for %s has %d days remaining; will automatically renew when 14 days remain\n", cfg.Host, daysLeft)
} }
} }
return n, errs return n, errs
} }
// acmeHandlers is a map of host to ACME handler. These // renewDaysBefore is how many days before expiration to renew certificates.
// are used to proxy ACME requests to the ACME client. const renewDaysBefore = 14
var acmeHandlers = make(map[string]*Handler)
...@@ -6,44 +6,44 @@ import ( ...@@ -6,44 +6,44 @@ import (
) )
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
storage = Storage("./letsencrypt") storage = Storage("./le_test")
if expected, actual := filepath.Join("letsencrypt", "sites"), storage.Sites(); actual != expected { if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected {
t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "sites", "test.com"), storage.Site("test.com"); actual != expected { if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected {
t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected {
t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected {
t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected {
t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users"), storage.Users(); actual != expected { if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected {
t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected {
t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected {
t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected {
t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual) t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual)
} }
// Test with empty emails // Test with empty emails
if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail), storage.User(emptyEmail); actual != expected { if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected {
t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual) t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected { if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected {
t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual) t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual)
} }
if expected, actual := filepath.Join("letsencrypt", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected { if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected {
t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual) t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual)
} }
} }
......
...@@ -144,7 +144,7 @@ func getEmail(cfg server.Config) string { ...@@ -144,7 +144,7 @@ func getEmail(cfg server.Config) string {
// Alas, we must bother the user and ask for an email address; // Alas, we must bother the user and ask for an email address;
// if they proceed they also agree to the SA. // if they proceed they also agree to the SA.
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
fmt.Println("Your sites will be served over HTTPS automatically using Let's Encrypt.") fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
fmt.Println(" " + saURL) // TODO: Show current SA link fmt.Println(" " + saURL) // TODO: Show current SA link
fmt.Println("Please enter your email address so you can recover your account if needed.") fmt.Println("Please enter your email address so you can recover your account if needed.")
......
...@@ -125,6 +125,11 @@ func TestGetUserAlreadyExists(t *testing.T) { ...@@ -125,6 +125,11 @@ func TestGetUserAlreadyExists(t *testing.T) {
} }
func TestGetEmail(t *testing.T) { func TestGetEmail(t *testing.T) {
// let's not clutter up the output
origStdout := os.Stdout
os.Stdout = nil
defer func() { os.Stdout = origStdout }()
storage = Storage("./testdata") storage = Storage("./testdata")
defer os.RemoveAll(string(storage)) defer os.RemoveAll(string(storage))
DefaultEmail = "test2@foo.com" DefaultEmail = "test2@foo.com"
......
...@@ -8,7 +8,7 @@ import "io" ...@@ -8,7 +8,7 @@ import "io"
// If checkDirectives is true, only valid directives will be allowed // If checkDirectives is true, only valid directives will be allowed
// otherwise we consider it a parse error. Server blocks are returned // otherwise we consider it a parse error. Server blocks are returned
// in the order in which they appear. // in the order in which they appear.
func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]serverBlock, error) { func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) {
p := parser{Dispenser: NewDispenser(filename, input)} p := parser{Dispenser: NewDispenser(filename, input)}
p.checkDirectives = checkDirectives p.checkDirectives = checkDirectives
blocks, err := p.parseAll() blocks, err := p.parseAll()
......
package parse package parse
import ( import (
"fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
...@@ -9,13 +10,13 @@ import ( ...@@ -9,13 +10,13 @@ import (
type parser struct { type parser struct {
Dispenser Dispenser
block serverBlock // current server block being parsed block ServerBlock // current server block being parsed
eof bool // if we encounter a valid EOF in a hard place eof bool // if we encounter a valid EOF in a hard place
checkDirectives bool // if true, directives must be known checkDirectives bool // if true, directives must be known
} }
func (p *parser) parseAll() ([]serverBlock, error) { func (p *parser) parseAll() ([]ServerBlock, error) {
var blocks []serverBlock var blocks []ServerBlock
for p.Next() { for p.Next() {
err := p.parseOne() err := p.parseOne()
...@@ -31,7 +32,7 @@ func (p *parser) parseAll() ([]serverBlock, error) { ...@@ -31,7 +32,7 @@ func (p *parser) parseAll() ([]serverBlock, error) {
} }
func (p *parser) parseOne() error { func (p *parser) parseOne() error {
p.block = serverBlock{Tokens: make(map[string][]token)} p.block = ServerBlock{Tokens: make(map[string][]token)}
err := p.begin() err := p.begin()
if err != nil { if err != nil {
...@@ -99,11 +100,11 @@ func (p *parser) addresses() error { ...@@ -99,11 +100,11 @@ func (p *parser) addresses() error {
} }
// Parse and save this address // Parse and save this address
host, port, err := standardAddress(tkn) addr, err := standardAddress(tkn)
if err != nil { if err != nil {
return err return err
} }
p.block.Addresses = append(p.block.Addresses, address{host, port}) p.block.Addresses = append(p.block.Addresses, addr)
} }
// Advance token and possibly break out of loop or return error // Advance token and possibly break out of loop or return error
...@@ -303,39 +304,57 @@ func (p *parser) closeCurlyBrace() error { ...@@ -303,39 +304,57 @@ func (p *parser) closeCurlyBrace() error {
return nil return nil
} }
// standardAddress turns the accepted host and port patterns // standardAddress parses an address string into a structured format with separate
// into a format accepted by net.Dial. // scheme, host, and port portions, as well as the original input string.
func standardAddress(str string) (host, port string, err error) { func standardAddress(str string) (address, error) {
var schemePort, splitPort string var scheme string
var err error
// first check for scheme and strip it off
input := str
if strings.HasPrefix(str, "https://") { if strings.HasPrefix(str, "https://") {
schemePort = "https" scheme = "https"
str = str[8:] str = str[8:]
} else if strings.HasPrefix(str, "http://") { } else if strings.HasPrefix(str, "http://") {
schemePort = "http" scheme = "http"
str = str[7:] str = str[7:]
} }
host, splitPort, err = net.SplitHostPort(str) // separate host and port
host, port, err := net.SplitHostPort(str)
if err != nil { if err != nil {
host, splitPort, err = net.SplitHostPort(str + ":") // tack on empty port host, port, err = net.SplitHostPort(str + ":")
// no error check here; return err at end of function
}
// see if we can set port based off scheme
if port == "" {
if scheme == "http" {
port = "80"
} else if scheme == "https" {
port = "443"
} }
if err != nil {
// ¯\_(ツ)_/¯
host = str
} }
if splitPort != "" { // repeated or conflicting scheme is confusing, so error
port = splitPort if scheme != "" && (port == "http" || port == "https") {
} else { return address{}, fmt.Errorf("[%s] scheme specified twice in address", str)
port = schemePort }
// standardize http and https ports to their respective port numbers
if port == "http" {
scheme = "http"
port = "80"
} else if port == "https" {
scheme = "https"
port = "443"
} }
return return address{Original: input, Scheme: scheme, Host: host, Port: port}, err
} }
// replaceEnvVars replaces environment variables that appear in the token // replaceEnvVars replaces environment variables that appear in the token
// and understands both the Unix $SYNTAX and Windows %SYNTAX%. // and understands both the $UNIX and %WINDOWS% syntaxes.
func replaceEnvVars(s string) string { func replaceEnvVars(s string) string {
s = replaceEnvReferences(s, "{%", "%}") s = replaceEnvReferences(s, "{%", "%}")
s = replaceEnvReferences(s, "{$", "}") s = replaceEnvReferences(s, "{$", "}")
...@@ -360,26 +379,26 @@ func replaceEnvReferences(s, refStart, refEnd string) string { ...@@ -360,26 +379,26 @@ func replaceEnvReferences(s, refStart, refEnd string) string {
} }
type ( type (
// serverBlock associates tokens with a list of addresses // ServerBlock associates tokens with a list of addresses
// and groups tokens by directive name. // and groups tokens by directive name.
serverBlock struct { ServerBlock struct {
Addresses []address Addresses []address
Tokens map[string][]token Tokens map[string][]token
} }
address struct { address struct {
Host, Port string Original, Scheme, Host, Port string
} }
) )
// HostList converts the list of addresses (hosts) // HostList converts the list of addresses that are
// that are associated with this server block into // associated with this server block into a slice of
// a slice of strings. Each string is a host:port // strings, where each address is as it was originally
// combination. // read from the input.
func (sb serverBlock) HostList() []string { func (sb ServerBlock) HostList() []string {
sbHosts := make([]string, len(sb.Addresses)) sbHosts := make([]string, len(sb.Addresses))
for j, addr := range sb.Addresses { for j, addr := range sb.Addresses {
sbHosts[j] = net.JoinHostPort(addr.Host, addr.Port) sbHosts[j] = addr.Original
} }
return sbHosts return sbHosts
} }
This diff is collapsed.
...@@ -3,11 +3,17 @@ ...@@ -3,11 +3,17 @@
package caddy package caddy
import ( import (
"bytes"
"encoding/gob" "encoding/gob"
"errors"
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"os/exec" "os/exec"
"path"
"github.com/mholt/caddy/caddy/letsencrypt"
"github.com/mholt/caddy/server"
) )
func init() { func init() {
...@@ -33,6 +39,12 @@ func Restart(newCaddyfile Input) error { ...@@ -33,6 +39,12 @@ func Restart(newCaddyfile Input) error {
caddyfileMu.Unlock() caddyfileMu.Unlock()
} }
// Get certificates for any new hosts in the new Caddyfile without causing downtime
err := getCertsForNewCaddyfile(newCaddyfile)
if err != nil {
return errors.New("TLS preload: " + err.Error())
}
if len(os.Args) == 0 { // this should never happen, but... if len(os.Args) == 0 { // this should never happen, but...
os.Args = []string{""} os.Args = []string{""}
} }
...@@ -61,7 +73,7 @@ func Restart(newCaddyfile Input) error { ...@@ -61,7 +73,7 @@ func Restart(newCaddyfile Input) error {
// Pass along relevant file descriptors to child process; ordering // Pass along relevant file descriptors to child process; ordering
// is very important since we rely on these being in certain positions. // is very important since we rely on these being in certain positions.
extraFiles := []*os.File{sigwpipe} extraFiles := []*os.File{sigwpipe} // fd 3
// Add file descriptors of all the sockets // Add file descriptors of all the sockets
serversMu.Lock() serversMu.Lock()
...@@ -110,3 +122,44 @@ func Restart(newCaddyfile Input) error { ...@@ -110,3 +122,44 @@ func Restart(newCaddyfile Input) error {
// Looks like child is successful; we can exit gracefully. // Looks like child is successful; we can exit gracefully.
return Stop() return Stop()
} }
func getCertsForNewCaddyfile(newCaddyfile Input) error {
// parse the new caddyfile only up to (and including) TLS
// so we can know what we need to get certs for.
configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCaddyfile.Path()), bytes.NewReader(newCaddyfile.Body()))
if err != nil {
return errors.New("loading Caddyfile: " + err.Error())
}
// first mark the configs that are qualified for managed TLS
letsencrypt.MarkQualified(configs)
// we must make sure port is set before we group by bind address
letsencrypt.EnableTLS(configs)
// we only need to issue certs for hosts where we already have an active listener
groupings, err := arrangeBindings(configs)
if err != nil {
return errors.New("arranging bindings: " + err.Error())
}
var configsToSetup []server.Config
serversMu.Lock()
GroupLoop:
for _, group := range groupings {
for _, server := range servers {
if server.Addr == group.BindAddr.String() {
configsToSetup = append(configsToSetup, group.Configs...)
continue GroupLoop
}
}
}
serversMu.Unlock()
// place certs on the disk
err = letsencrypt.ObtainCerts(configsToSetup, letsencrypt.AlternatePort)
if err != nil {
return errors.New("obtaining certs: " + err.Error())
}
return nil
}
...@@ -2,7 +2,6 @@ package setup ...@@ -2,7 +2,6 @@ package setup
import ( import (
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strconv" "strconv"
"testing" "testing"
...@@ -13,16 +12,19 @@ import ( ...@@ -13,16 +12,19 @@ import (
// because the Startup and Shutdown functions share virtually the // because the Startup and Shutdown functions share virtually the
// same functionality // same functionality
func TestStartup(t *testing.T) { func TestStartup(t *testing.T) {
tempDirPath, err := getTempDirPath() tempDirPath, err := getTempDirPath()
if err != nil { if err != nil {
t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
} }
testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown.go") testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown")
defer func() {
// clean up after non-blocking startup function quits
time.Sleep(500 * time.Millisecond)
os.RemoveAll(testDir)
}()
osSenitiveTestDir := filepath.FromSlash(testDir) osSenitiveTestDir := filepath.FromSlash(testDir)
os.RemoveAll(osSenitiveTestDir) // start with a clean slate
exec.Command("rm", "-r", osSenitiveTestDir).Run() // removes osSenitiveTestDir from the OS's temp directory, if the osSenitiveTestDir already exists
tests := []struct { tests := []struct {
input string input string
...@@ -53,6 +55,5 @@ func TestStartup(t *testing.T) { ...@@ -53,6 +55,5 @@ func TestStartup(t *testing.T) {
if err != nil && !test.shouldRemoveErr { if err != nil && !test.shouldRemoveErr {
t.Errorf("Test %d recieved an error of:\n%v", i, err) t.Errorf("Test %d recieved an error of:\n%v", i, err)
} }
} }
} }
...@@ -11,12 +11,12 @@ import ( ...@@ -11,12 +11,12 @@ import (
// TLS sets up the TLS configuration (but does not activate Let's Encrypt; that is handled elsewhere). // TLS sets up the TLS configuration (but does not activate Let's Encrypt; that is handled elsewhere).
func TLS(c *Controller) (middleware.Middleware, error) { func TLS(c *Controller) (middleware.Middleware, error) {
if c.Port == "http" { if c.Scheme == "http" && c.Port != "80" {
c.TLS.Enabled = false c.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+ log.Printf("[WARNING] TLS disabled for %s://%s. To force TLS over the plaintext HTTP port, "+
"specify port 80 explicitly (https://%s:80).", c.Port, c.Host, c.Host) "specify port 80 explicitly (https://%s:80).", c.Scheme, c.Address(), c.Host)
} else { } else {
c.TLS.Enabled = true // they had a tls directive, so assume it's on unless we confirm otherwise later c.TLS.Enabled = true
} }
for c.Next() { for c.Next() {
...@@ -32,18 +32,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -32,18 +32,9 @@ func TLS(c *Controller) (middleware.Middleware, error) {
case 2: case 2:
c.TLS.Certificate = args[0] c.TLS.Certificate = args[0]
c.TLS.Key = args[1] c.TLS.Key = args[1]
// manual HTTPS configuration without port specified should be
// served on the HTTPS port; that is what user would expect, and
// makes it consistent with how the letsencrypt package works.
if c.Port == "" {
c.Port = "https"
}
default:
return nil, c.ArgErr()
} }
// Optional block // Optional block with extra parameters
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { switch c.Val() {
case "protocols": case "protocols":
...@@ -74,6 +65,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -74,6 +65,9 @@ func TLS(c *Controller) (middleware.Middleware, error) {
if len(c.TLS.ClientCerts) == 0 { if len(c.TLS.ClientCerts) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
// TODO: Allow this? It's a bad idea to allow HTTP. If we do this, make sure invoking tls at all (even manually) also sets up a redirect if possible?
// case "allow_http":
// c.TLS.DisableHTTPRedir = true
default: default:
return nil, c.Errf("Unknown keyword '%s'", c.Val()) return nil, c.Errf("Unknown keyword '%s'", c.Val())
} }
...@@ -85,8 +79,9 @@ func TLS(c *Controller) (middleware.Middleware, error) { ...@@ -85,8 +79,9 @@ func TLS(c *Controller) (middleware.Middleware, error) {
return nil, nil return nil, nil
} }
// SetDefaultTLSParams sets the default TLS cipher suites, protocol versions and server preferences // SetDefaultTLSParams sets the default TLS cipher suites, protocol versions,
// of a server.Config if they were not previously set. // and server preferences of a server.Config if they were not previously set
// (it does not overwrite; only fills in missing values).
func SetDefaultTLSParams(c *server.Config) { func SetDefaultTLSParams(c *server.Config) {
// If no ciphers provided, use all that Caddy supports for the protocol // If no ciphers provided, use all that Caddy supports for the protocol
if len(c.TLS.Ciphers) == 0 { if len(c.TLS.Ciphers) == 0 {
...@@ -106,6 +101,11 @@ func SetDefaultTLSParams(c *server.Config) { ...@@ -106,6 +101,11 @@ func SetDefaultTLSParams(c *server.Config) {
// Prefer server cipher suites // Prefer server cipher suites
c.TLS.PreferServerCipherSuites = true c.TLS.PreferServerCipherSuites = true
// Default TLS port is 443; only use if port is not manually specified
if c.Port == "" {
c.Port = "443"
}
} }
// Map of supported protocols // Map of supported protocols
......
...@@ -64,11 +64,12 @@ func TestTLSParseBasic(t *testing.T) { ...@@ -64,11 +64,12 @@ func TestTLSParseBasic(t *testing.T) {
} }
func TestTLSParseIncompleteParams(t *testing.T) { func TestTLSParseIncompleteParams(t *testing.T) {
// This doesn't do anything useful but is allowed in case the user wants to be explicit
// about TLS being enabled...
c := NewTestController(`tls`) c := NewTestController(`tls`)
_, err := TLS(c) _, err := TLS(c)
if err == nil { if err != nil {
t.Errorf("Expected errors (first check), but no error returned") t.Errorf("Expected no error, but got %v", err)
} }
} }
...@@ -93,10 +94,39 @@ func TestTLSParseWithOptionalParams(t *testing.T) { ...@@ -93,10 +94,39 @@ func TestTLSParseWithOptionalParams(t *testing.T) {
} }
if len(c.TLS.Ciphers)-1 != 3 { if len(c.TLS.Ciphers)-1 != 3 {
t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)) t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
}
}
func TestTLSDefaultWithOptionalParams(t *testing.T) {
params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA
}`
c := NewTestController(params)
_, err := TLS(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
if len(c.TLS.Ciphers)-1 != 1 {
t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
} }
} }
// TODO: If we allow this... but probably not a good idea.
// func TestTLSDisableHTTPRedirect(t *testing.T) {
// c := NewTestController(`tls {
// allow_http
// }`)
// _, err := TLS(c)
// if err != nil {
// t.Errorf("Expected no error, but got %v", err)
// }
// if !c.TLS.DisableHTTPRedir {
// t.Error("Expected HTTP redirect to be disabled, but it wasn't")
// }
// }
func TestTLSParseWithWrongOptionalParams(t *testing.T) { func TestTLSParseWithWrongOptionalParams(t *testing.T) {
// Test protocols wrong params // Test protocols wrong params
params := `tls cert.crt cert.key { params := `tls cert.crt cert.key {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/mholt/caddy/caddy" "github.com/mholt/caddy/caddy"
"github.com/mholt/caddy/caddy/letsencrypt" "github.com/mholt/caddy/caddy/letsencrypt"
"github.com/xenolf/lego/acme"
) )
var ( var (
...@@ -53,6 +54,7 @@ func main() { ...@@ -53,6 +54,7 @@ func main() {
caddy.AppName = appName caddy.AppName = appName
caddy.AppVersion = appVersion caddy.AppVersion = appVersion
acme.UserAgent = appName + "/" + appVersion
// set up process log before anything bad happens // set up process log before anything bad happens
switch logfile { switch logfile {
......
...@@ -40,8 +40,8 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) * ...@@ -40,8 +40,8 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) *
return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz}
} }
// Write wraps underlying WriteHeader method and compresses if filters // WriteHeader wraps underlying WriteHeader method and
// are satisfied. // compresses if filters are satisfied.
func (r *ResponseFilterWriter) WriteHeader(code int) { func (r *ResponseFilterWriter) WriteHeader(code int) {
// Determine if compression should be used or not. // Determine if compression should be used or not.
r.shouldCompress = true r.shouldCompress = true
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
) )
func TestLengthFilter(t *testing.T) { func TestLengthFilter(t *testing.T) {
var filters []ResponseFilter = []ResponseFilter{ var filters = []ResponseFilter{
LengthFilter(100), LengthFilter(100),
LengthFilter(1000), LengthFilter(1000),
LengthFilter(0), LengthFilter(0),
......
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// Operators
const ( const (
// Operators
Is = "is" Is = "is"
Not = "not" Not = "not"
Has = "has" Has = "has"
......
...@@ -13,12 +13,12 @@ import ( ...@@ -13,12 +13,12 @@ import (
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
// RewriteResult is the result of a rewrite // Result is the result of a rewrite
type RewriteResult int type Result int
const ( const (
// RewriteIgnored is returned when rewrite is not done on request. // RewriteIgnored is returned when rewrite is not done on request.
RewriteIgnored RewriteResult = iota RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request. // RewriteDone is returned when rewrite is done on request.
RewriteDone RewriteDone
// RewriteStatus is returned when rewrite is not needed and status code should be set // RewriteStatus is returned when rewrite is not needed and status code should be set
...@@ -55,7 +55,7 @@ outer: ...@@ -55,7 +55,7 @@ outer:
// Rule describes an internal location rewrite rule. // Rule describes an internal location rewrite rule.
type Rule interface { type Rule interface {
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
Rewrite(http.FileSystem, *http.Request) RewriteResult Rewrite(http.FileSystem, *http.Request) Result
} }
// SimpleRule is a simple rewrite rule. // SimpleRule is a simple rewrite rule.
...@@ -69,7 +69,7 @@ func NewSimpleRule(from, to string) SimpleRule { ...@@ -69,7 +69,7 @@ func NewSimpleRule(from, to string) SimpleRule {
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) RewriteResult { func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
if s.From == r.URL.Path { if s.From == r.URL.Path {
// take note of this rewrite for internal use by fastcgi // take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL // all we need is the URI, not full URL
...@@ -102,7 +102,7 @@ type ComplexRule struct { ...@@ -102,7 +102,7 @@ type ComplexRule struct {
*regexp.Regexp *regexp.Regexp
} }
// NewRegexpRule creates a new RegexpRule. It returns an error if regexp // NewComplexRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid. // pattern (pattern) or extensions (ext) are invalid.
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
// validate regexp if present // validate regexp if present
...@@ -136,7 +136,7 @@ func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If ...@@ -136,7 +136,7 @@ func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re RewriteResult) { func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) {
rPath := req.URL.Path rPath := req.URL.Path
replacer := newReplacer(req) replacer := newReplacer(req)
......
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
// To attempts rewrite. It attempts to rewrite to first valid path // To attempts rewrite. It attempts to rewrite to first valid path
// or the last path if none of the paths are valid. // or the last path if none of the paths are valid.
// Returns true if rewrite is successful and false otherwise. // Returns true if rewrite is successful and false otherwise.
func To(fs http.FileSystem, r *http.Request, to string, replacer middleware.Replacer) RewriteResult { func To(fs http.FileSystem, r *http.Request, to string, replacer middleware.Replacer) Result {
tos := strings.Fields(to) tos := strings.Fields(to)
// try each rewrite paths // try each rewrite paths
......
...@@ -17,6 +17,9 @@ type Config struct { ...@@ -17,6 +17,9 @@ type Config struct {
// The port to listen on // The port to listen on
Port string Port string
// The protocol (http/https) to serve with this config; only set if user explicitly specifies it
Scheme string
// The directory from which to serve files // The directory from which to serve files
Root string Root string
...@@ -66,6 +69,8 @@ type TLSConfig struct { ...@@ -66,6 +69,8 @@ type TLSConfig struct {
Certificate string Certificate string
Key string Key string
LetsEncryptEmail string LetsEncryptEmail string
Managed bool // will be set to true if config qualifies for automatic, managed TLS
//DisableHTTPRedir bool // TODO: not a good idea - should we really allow it?
OCSPStaple []byte OCSPStaple []byte
Ciphers []uint16 Ciphers []uint16
ProtocolMinVersion uint16 ProtocolMinVersion uint16
......
...@@ -18,8 +18,8 @@ func TestConfigAddress(t *testing.T) { ...@@ -18,8 +18,8 @@ func TestConfigAddress(t *testing.T) {
t.Errorf("Expected '%s' but got '%s'", expected, actual) t.Errorf("Expected '%s' but got '%s'", expected, actual)
} }
cfg = Config{Host: "::1", Port: "https"} cfg = Config{Host: "::1", Port: "443"}
if actual, expected := cfg.Address(), "[::1]:https"; expected != actual { if actual, expected := cfg.Address(), "[::1]:443"; expected != actual {
t.Errorf("Expected '%s' but got '%s'", expected, actual) t.Errorf("Expected '%s' but got '%s'", expected, actual)
} }
} }
...@@ -33,6 +33,7 @@ type Server struct { ...@@ -33,6 +33,7 @@ type Server struct {
httpWg sync.WaitGroup // used to wait on outstanding connections httpWg sync.WaitGroup // used to wait on outstanding connections
startChan chan struct{} // used to block until server is finished starting startChan chan struct{} // used to block until server is finished starting
connTimeout time.Duration // the maximum duration of a graceful shutdown connTimeout time.Duration // the maximum duration of a graceful shutdown
ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
} }
// ListenerFile represents a listener. // ListenerFile represents a listener.
...@@ -41,6 +42,11 @@ type ListenerFile interface { ...@@ -41,6 +42,11 @@ type ListenerFile interface {
File() (*os.File, error) File() (*os.File, error)
} }
// OptionalCallback is a function that may or may not handle a request.
// It returns whether or not it handled the request. If it handled the
// request, it is presumed that no further request handling should occur.
type OptionalCallback func(http.ResponseWriter, *http.Request) bool
// New creates a new Server which will bind to addr and serve // New creates a new Server which will bind to addr and serve
// the sites/hosts configured in configs. Its listener will // the sites/hosts configured in configs. Its listener will
// gracefully close when the server is stopped which will take // gracefully close when the server is stopped which will take
...@@ -309,6 +315,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -309,6 +315,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
}() }()
w.Header().Set("Server", "Caddy")
// Execute the optional request callback if it exists
if s.ReqCallback != nil && s.ReqCallback(w, r) {
return
}
host, _, err := net.SplitHostPort(r.Host) host, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {
host = r.Host // oh well host = r.Host // oh well
...@@ -324,8 +337,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -324,8 +337,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if vh, ok := s.vhosts[host]; ok { if vh, ok := s.vhosts[host]; ok {
w.Header().Set("Server", "Caddy")
status, _ := vh.stack.ServeHTTP(w, r) status, _ := vh.stack.ServeHTTP(w, r)
// Fallback error response in case error handling wasn't chained in // Fallback error response in case error handling wasn't chained in
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment