Commit b019501b authored by Matthew Holt's avatar Matthew Holt

Merge branch 'master' into telemetry

# Conflicts:
#	caddy/caddymain/run.go
#	caddyhttp/httpserver/plugin.go
#	caddytls/client.go
parents 8039a712 2922d09b
...@@ -103,7 +103,7 @@ While we really do value your requests and implement many of them, not all featu ...@@ -103,7 +103,7 @@ While we really do value your requests and implement many of them, not all featu
### Improving documentation ### Improving documentation
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, feel free to contribute at the [caddyserver/website](https://github.com/caddyserver/website) repository! Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, please submit an issue here describing the change to make.
Note that plugin documentation is not hosted by the Caddy website, other than basic usage examples. They are managed by the individual plugin authors, and you will have to contact them to change their documentation. Note that plugin documentation is not hosted by the Caddy website, other than basic usage examples. They are managed by the individual plugin authors, and you will have to contact them to change their documentation.
......
<p align="center"> <p align="center">
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a> <a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36338535-05fb646a-136f-11e8-987b-e6901e717d5a.png" alt="Caddy" width="450"></a>
</p> </p>
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3> <h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p> <p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
--- ---
Caddy is fast, easy to use, and makes you more productive. Caddy is a **production-ready** open-source web server that is fast, easy to use, and makes you more productive.
Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android). Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android).
...@@ -41,31 +41,35 @@ Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.co ...@@ -41,31 +41,35 @@ Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.co
- **Automatic HTTPS** on by default (via [Let's Encrypt](https://letsencrypt.org)) - **Automatic HTTPS** on by default (via [Let's Encrypt](https://letsencrypt.org))
- **HTTP/2** by default - **HTTP/2** by default
- **Virtual hosting** so multiple sites just work - **Virtual hosting** so multiple sites just work
- Experimental **QUIC support** for those that like speed - Experimental **QUIC support** for cutting-edge transmissions
- TLS session ticket **key rotation** for more secure connections - TLS session ticket **key rotation** for more secure connections
- **Extensible with plugins** because a convenient web server is a helpful one - **Extensible with plugins** because a convenient web server is a helpful one
- **Runs anywhere** with **no external dependencies** (not even libc) - **Runs anywhere** with **no external dependencies** (not even libc)
There's way more, too! [See all features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download). [See a more complete list of features built into Caddy.](https://caddyserver.com/features) On top of all those, Caddy does even more with plugins: choose which plugins you want at [download](https://caddyserver.com/download).
Altogether, Caddy can do things other web servers simply cannot do. Its features and plugins save you time and mistakes, and will cheer you up. Your Caddy instance takes care of the details for you!
## Install ## Install
Caddy binaries have no dependencies and are available for every platform. Get Caddy any one of these ways: Caddy binaries have no dependencies and are available for every platform. Get Caddy either of these ways:
- **[Download page](https://caddyserver.com/download)** (RECOMMENDED) allows you to customize your build in the browser
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for pre-built, vanilla binaries
- **[Download page](https://caddyserver.com/download)** allows you to
customize your build in the browser
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for
pre-built, vanilla binaries
## Build ## Build
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building: To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds` - Get the source with `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go` - Now `cd $GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
Then make sure the `caddy` binary is in your PATH. Then make sure the `caddy` binary is in your PATH.
To build for other platforms, use build.go with the `--goos` and `--goarch` flags.
## Quick Start ## Quick Start
...@@ -85,7 +89,7 @@ If the `caddy` binary has permission to bind to low ports and your domain name's ...@@ -85,7 +89,7 @@ If the `caddy` binary has permission to bind to low ports and your domain name's
caddy -host example.com caddy -host example.com
``` ```
This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you! This command serves static files from the current directory over HTTPS. Certificates are automatically obtained and renewed for you! Caddy is also automatically configuring ports 80 and 443 for you, and redirecting HTTP to HTTPS. Cool, huh?
### Customizing your site ### Customizing your site
...@@ -115,7 +119,7 @@ To host multiple sites and do more with the Caddyfile, please see the [Caddyfile ...@@ -115,7 +119,7 @@ To host multiple sites and do more with the Caddyfile, please see the [Caddyfile
Sites with qualifying hostnames are served over [HTTPS by default](https://caddyserver.com/docs/automatic-https). Sites with qualifying hostnames are served over [HTTPS by default](https://caddyserver.com/docs/automatic-https).
Caddy has a command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details. Caddy has a nice little command line interface. Run `caddy -h` to view basic help or see the [CLI documentation](https://caddyserver.com/docs/cli) for details.
## Running in Production ## Running in Production
...@@ -139,7 +143,7 @@ Please see our [contributing guidelines](https://github.com/mholt/caddy/blob/mas ...@@ -139,7 +143,7 @@ Please see our [contributing guidelines](https://github.com/mholt/caddy/blob/mas
We use GitHub issues and pull requests only for discussing bug reports and the development of specific changes. We welcome all other topics on the [forum](https://caddy.community)! We use GitHub issues and pull requests only for discussing bug reports and the development of specific changes. We welcome all other topics on the [forum](https://caddy.community)!
If you want to contribute to the documentation, please submit pull requests to [caddyserver/website](https://github.com/caddyserver/website). If you want to contribute to the documentation, please [submit an issue](https://github.com/mholt/caddy/issues/new) describing the change that should be made.
Thanks for making Caddy -- and the Web -- better! Thanks for making Caddy -- and the Web -- better!
...@@ -158,6 +162,6 @@ We thank them for their services. **If you want to help keep Caddy free, please ...@@ -158,6 +162,6 @@ We thank them for their services. **If you want to help keep Caddy free, please
Caddy was born out of the need for a "batteries-included" web server that runs anywhere and doesn't have to take its configuration with it. Caddy took inspiration from [spark](https://github.com/rif/spark), [nginx](https://github.com/nginx/nginx), lighttpd, Caddy was born out of the need for a "batteries-included" web server that runs anywhere and doesn't have to take its configuration with it. Caddy took inspiration from [spark](https://github.com/rif/spark), [nginx](https://github.com/nginx/nginx), lighttpd,
[Websocketd](https://github.com/joewalnes/websocketd) and [Vagrant](https://www.vagrantup.com/), which provides a pleasant mixture of features from each of them. [Websocketd](https://github.com/joewalnes/websocketd) and [Vagrant](https://www.vagrantup.com/), which provides a pleasant mixture of features from each of them.
**The name "Caddy":** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand). **The name "Caddy" is trademarked:** The name of the software is "Caddy", not "Caddy Server" or "CaddyServer". Please call it "Caddy" or, if you wish to clarify, "the Caddy web server". See [brand guidelines](https://caddyserver.com/brand). Caddy is a registered trademark of Light Code Labs, LLC.
*Author on Twitter: [@mholt6](https://twitter.com/mholt6)* *Author on Twitter: [@mholt6](https://twitter.com/mholt6)*
...@@ -802,7 +802,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res ...@@ -802,7 +802,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
continue continue
} }
if strings.Contains(err.Error(), "use of closed network connection") { if strings.Contains(err.Error(), "use of closed network connection") {
// this error is normal when closing the listener // this error is normal when closing the listener; see https://github.com/golang/go/issues/4373
continue continue
} }
log.Println(err) log.Println(err)
......
...@@ -31,7 +31,7 @@ import ( ...@@ -31,7 +31,7 @@ import (
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls" "github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
_ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type _ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type
...@@ -43,7 +43,7 @@ func init() { ...@@ -43,7 +43,7 @@ func init() {
setVersion() setVersion()
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement") flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory") flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v02.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge") flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge") flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge")
flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable") flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable")
......
...@@ -265,14 +265,19 @@ func (p *parser) doImport() error { ...@@ -265,14 +265,19 @@ func (p *parser) doImport() error {
} else { } else {
globPattern = importPattern globPattern = importPattern
} }
if strings.Count(globPattern, "*") > 1 || strings.Count(globPattern, "?") > 1 ||
(strings.Contains(globPattern, "[") && strings.Contains(globPattern, "]")) {
// See issue #2096 - a pattern with many glob expansions can hang for too long
return p.Errf("Glob pattern may only contain one wildcard (*), but has others: %s", globPattern)
}
matches, err = filepath.Glob(globPattern) matches, err = filepath.Glob(globPattern)
if err != nil { if err != nil {
return p.Errf("Failed to use import pattern %s: %v", importPattern, err) return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
} }
if len(matches) == 0 { if len(matches) == 0 {
if strings.Contains(globPattern, "*") { if strings.ContainsAny(globPattern, "*?[]") {
log.Printf("[WARNING] No files matching import pattern: %s", importPattern) log.Printf("[WARNING] No files matching import glob pattern: %s", importPattern)
} else { } else {
return p.Errf("File to import not found: %s", importPattern) return p.Errf("File to import not found: %s", importPattern)
} }
...@@ -443,7 +448,7 @@ func replaceEnvReferences(s, refStart, refEnd string) string { ...@@ -443,7 +448,7 @@ func replaceEnvReferences(s, refStart, refEnd string) string {
index := strings.Index(s, refStart) index := strings.Index(s, refStart)
for index != -1 { for index != -1 {
endIndex := strings.Index(s, refEnd) endIndex := strings.Index(s, refEnd)
if endIndex != -1 { if endIndex > index+len(refStart) {
ref := s[index : endIndex+len(refEnd)] ref := s[index : endIndex+len(refEnd)]
s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1) s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
} else { } else {
......
...@@ -228,6 +228,17 @@ func TestParseOneAndImport(t *testing.T) { ...@@ -228,6 +228,17 @@ func TestParseOneAndImport(t *testing.T) {
{`""`, false, []string{}, map[string]int{}}, {`""`, false, []string{}, map[string]int{}},
{``, false, []string{}, map[string]int{}}, {``, false, []string{}, map[string]int{}},
// test cases found by fuzzing!
{`import }{$"`, true, []string{}, map[string]int{}},
{`import /*/*.txt`, true, []string{}, map[string]int{}},
{`import /???/?*?o`, true, []string{}, map[string]int{}},
{`import /??`, true, []string{}, map[string]int{}},
{`import /[a-z]`, true, []string{}, map[string]int{}},
{`import {$}`, true, []string{}, map[string]int{}},
{`import {%}`, true, []string{}, map[string]int{}},
{`import {$$}`, true, []string{}, map[string]int{}},
{`import {%%}`, true, []string{}, map[string]int{}},
} { } {
result, err := testParseOne(test.input) result, err := testParseOne(test.input)
......
...@@ -46,5 +46,4 @@ import ( ...@@ -46,5 +46,4 @@ import (
_ "github.com/mholt/caddy/caddyhttp/timeouts" _ "github.com/mholt/caddy/caddyhttp/timeouts"
_ "github.com/mholt/caddy/caddyhttp/websocket" _ "github.com/mholt/caddy/caddyhttp/websocket"
_ "github.com/mholt/caddy/onevent" _ "github.com/mholt/caddy/onevent"
_ "github.com/mholt/caddy/startupshutdown"
) )
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
// ensure that the standard plugins are in fact plugged in // ensure that the standard plugins are in fact plugged in
// and registered properly; this is a quick/naive way to do it. // and registered properly; this is a quick/naive way to do it.
func TestStandardPlugins(t *testing.T) { func TestStandardPlugins(t *testing.T) {
numStandardPlugins := 33 // importing caddyhttp plugs in this many plugins numStandardPlugins := 31 // importing caddyhttp plugs in this many plugins
s := caddy.DescribePlugins() s := caddy.DescribePlugins()
if got, want := strings.Count(s, "\n"), numStandardPlugins+5; got != want { if got, want := strings.Count(s, "\n"), numStandardPlugins+5; got != want {
t.Errorf("Expected all standard plugins to be plugged in, got:\n%s", s) t.Errorf("Expected all standard plugins to be plugged in, got:\n%s", s)
......
...@@ -33,8 +33,11 @@ import ( ...@@ -33,8 +33,11 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"crypto/tls"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
"github.com/mholt/caddy/caddytls"
) )
// Handler is a middleware type that can handle requests as a FastCGI client. // Handler is a middleware type that can handle requests as a FastCGI client.
...@@ -323,6 +326,19 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string] ...@@ -323,6 +326,19 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
// Some web apps rely on knowing HTTPS or not // Some web apps rely on knowing HTTPS or not
if r.TLS != nil { if r.TLS != nil {
env["HTTPS"] = "on" env["HTTPS"] = "on"
// and pass the protocol details in a manner compatible with apache's mod_ssl
// (which is why they have a SSL_ prefix and not TLS_).
v, ok := tlsProtocolStringToMap[r.TLS.Version]
if ok {
env["SSL_PROTOCOL"] = v
}
// and pass the cipher suite in a manner compatible with apache's mod_ssl
for k, v := range caddytls.SupportedCiphersMap {
if v == r.TLS.CipherSuite {
env["SSL_CIPHER"] = k
break
}
}
} }
// Add env variables from config (with support for placeholders in values) // Add env variables from config (with support for placeholders in values)
...@@ -465,3 +481,11 @@ type LogError string ...@@ -465,3 +481,11 @@ type LogError string
func (l LogError) Error() string { func (l LogError) Error() string {
return string(l) return string(l)
} }
// Map of supported protocols to Apache ssl_mod format
// Note that these are slightly different from SupportedProtocols in caddytls/config.go's
var tlsProtocolStringToMap = map[uint16]string{
tls.VersionTLS10: "TLSv1",
tls.VersionTLS11: "TLSv1.1",
tls.VersionTLS12: "TLSv1.2",
}
...@@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { ...@@ -100,8 +100,8 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
} }
cfg.TLS.Enabled = true cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https" cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) { if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host) _, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -44,6 +44,7 @@ type Logger struct { ...@@ -44,6 +44,7 @@ type Logger struct {
V4ipMask net.IPMask V4ipMask net.IPMask
V6ipMask net.IPMask V6ipMask net.IPMask
IPMaskExists bool IPMaskExists bool
Exceptions []string
} }
// NewTestLogger creates logger suitable for testing purposes // NewTestLogger creates logger suitable for testing purposes
...@@ -84,6 +85,17 @@ func (l Logger) MaskIP(ip string) string { ...@@ -84,6 +85,17 @@ func (l Logger) MaskIP(ip string) string {
} }
// ShouldLog returns true if the path is not exempted from
// being logged (i.e. it is not found in l.Exceptions).
func (l Logger) ShouldLog(path string) bool {
for _, exc := range l.Exceptions {
if Path(path).Matches(exc) {
return false
}
}
return true
}
// Attach binds logger Start and Close functions to // Attach binds logger Start and Close functions to
// controller's OnStartup and OnShutdown hooks. // controller's OnStartup and OnShutdown hooks.
func (l *Logger) Attach(controller *caddy.Controller) { func (l *Logger) Attach(controller *caddy.Controller) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package httpserver package httpserver
import ( import (
"crypto/tls"
"flag" "flag"
"fmt" "fmt"
"log" "log"
...@@ -123,15 +124,17 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -123,15 +124,17 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
// For each address in each server block, make a new config // For each address in each server block, make a new config
for _, sb := range serverBlocks { for _, sb := range serverBlocks {
for _, key := range sb.Keys { for _, key := range sb.Keys {
key = strings.ToLower(key)
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
addr, err := standardizeAddress(key) addr, err := standardizeAddress(key)
if err != nil { if err != nil {
return serverBlocks, err return serverBlocks, err
} }
addr = addr.Normalize()
key = addr.Key()
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
// Fill in address components from command line so that middleware // Fill in address components from command line so that middleware
// have access to the correct information during setup // have access to the correct information during setup
if addr.Host == "" && Host != DefaultHost { if addr.Host == "" && Host != DefaultHost {
...@@ -146,7 +149,7 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd ...@@ -146,7 +149,7 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
if addrCopy.Port == "" && Port == DefaultPort { if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port addrCopy.Port = Port
} }
addrStr := strings.ToLower(addrCopy.String()) addrStr := addrCopy.String()
if otherSiteKey, dup := siteAddrs[addrStr]; dup { if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr) err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) || if (addrCopy.Host == Host && Host != DefaultHost) ||
...@@ -218,6 +221,13 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -218,6 +221,13 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
} }
} }
// Iterate each site configuration and make sure that:
// 1) TLS is disabled for explicitly-HTTP sites (necessary
// when an HTTP address shares a block containing tls)
// 2) if QUIC is enabled, TLS ClientAuth is not, because
// currently, QUIC does not support ClientAuth (TODO:
// revisit this when our QUIC implementation supports it)
// 3) if TLS ClientAuth is used, StrictHostMatching is on
var atLeastOneSiteLooksLikeProduction bool var atLeastOneSiteLooksLikeProduction bool
for _, cfg := range h.siteConfigs { for _, cfg := range h.siteConfigs {
// see if all the addresses (both sites and // see if all the addresses (both sites and
...@@ -254,6 +264,17 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -254,6 +264,17 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
// instead of 443 because it doesn't know about TLS. // instead of 443 because it doesn't know about TLS.
cfg.Addr.Port = HTTPSPort cfg.Addr.Port = HTTPSPort
} }
if cfg.TLS.ClientAuth != tls.NoClientCert {
if QUIC {
return nil, fmt.Errorf("cannot enable TLS client authentication with QUIC, because QUIC does not yet support it")
}
// this must be enabled so that a client cannot connect
// using SNI for another site on this listener that
// does NOT require ClientAuth, and then send HTTP
// requests with the Host header of this site which DOES
// require client auth, thus bypassing it...
cfg.StrictHostMatching = true
}
} }
// we must map (group) each config to a bind address // we must map (group) each config to a bind address
...@@ -287,12 +308,22 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) { ...@@ -287,12 +308,22 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
return servers, nil return servers, nil
} }
// normalizedKey returns "normalized" key representation:
// scheme and host names are lowered, everything else stays the same
func normalizedKey(key string) string {
addr, err := standardizeAddress(key)
if err != nil {
return key
}
return addr.Normalize().Key()
}
// GetConfig gets the SiteConfig that corresponds to c. // GetConfig gets the SiteConfig that corresponds to c.
// If none exist (should only happen in tests), then a // If none exist (should only happen in tests), then a
// new, empty one will be created. // new, empty one will be created.
func GetConfig(c *caddy.Controller) *SiteConfig { func GetConfig(c *caddy.Controller) *SiteConfig {
ctx := c.Context().(*httpContext) ctx := c.Context().(*httpContext)
key := strings.ToLower(c.Key) key := normalizedKey(c.Key)
if cfg, ok := ctx.keysToSiteConfigs[key]; ok { if cfg, ok := ctx.keysToSiteConfigs[key]; ok {
return cfg return cfg
} }
...@@ -396,6 +427,43 @@ func (a Address) VHost() string { ...@@ -396,6 +427,43 @@ func (a Address) VHost() string {
return a.Original return a.Original
} }
// Normalize normalizes URL: turn scheme and host names into lower case
func (a Address) Normalize() Address {
path := a.Path
if !CaseSensitivePath {
path = strings.ToLower(path)
}
return Address{
Original: a.Original,
Scheme: strings.ToLower(a.Scheme),
Host: strings.ToLower(a.Host),
Port: a.Port,
Path: path,
}
}
// Key is similar to String, just replaces scheme and host values with modified values.
// Unlike String it doesn't add anything default (scheme, port, etc)
func (a Address) Key() string {
res := ""
if a.Scheme != "" {
res += a.Scheme + "://"
}
if a.Host != "" {
res += a.Host
}
if a.Port != "" {
if strings.HasPrefix(a.Original[len(res):], ":"+a.Port) {
// insert port only if the original has its own explicit port
res += ":" + a.Port
}
}
if a.Path != "" {
res += a.Path
}
return res
}
// standardizeAddress parses an address string into a structured format with separate // standardizeAddress parses an address string into a structured format with separate
// scheme, host, port, and path portions, as well as the original input string. // scheme, host, port, and path portions, as well as the original input string.
func standardizeAddress(str string) (Address, error) { func standardizeAddress(str string) (Address, error) {
...@@ -523,6 +591,7 @@ var directives = []string{ ...@@ -523,6 +591,7 @@ var directives = []string{
"startup", // TODO: Deprecate this directive "startup", // TODO: Deprecate this directive
"shutdown", // TODO: Deprecate this directive "shutdown", // TODO: Deprecate this directive
"on", "on",
"supervisor", // github.com/lucaslorentz/caddy-supervisor
"request_id", "request_id",
"realip", // github.com/captncraig/caddy-realip "realip", // github.com/captncraig/caddy-realip
"git", // github.com/abiosoft/caddy-git "git", // github.com/abiosoft/caddy-git
...@@ -538,13 +607,13 @@ var directives = []string{ ...@@ -538,13 +607,13 @@ var directives = []string{
"ext", "ext",
"gzip", "gzip",
"header", "header",
"geoip", // github.com/kodnaplakal/caddy-geoip
"errors", "errors",
"authz", // github.com/casbin/caddy-authz "authz", // github.com/casbin/caddy-authz
"filter", // github.com/echocat/caddy-filter "filter", // github.com/echocat/caddy-filter
"minify", // github.com/hacdias/caddy-minify "minify", // github.com/hacdias/caddy-minify
"ipfilter", // github.com/pyed/ipfilter "ipfilter", // github.com/pyed/ipfilter
"ratelimit", // github.com/xuqingfeng/caddy-rate-limit "ratelimit", // github.com/xuqingfeng/caddy-rate-limit
"search", // github.com/pedronasser/caddy-search
"expires", // github.com/epicagency/caddy-expires "expires", // github.com/epicagency/caddy-expires
"forwardproxy", // github.com/caddyserver/forwardproxy "forwardproxy", // github.com/caddyserver/forwardproxy
"basicauth", "basicauth",
......
...@@ -18,6 +18,10 @@ import ( ...@@ -18,6 +18,10 @@ import (
"strings" "strings"
"testing" "testing"
"sort"
"fmt"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
) )
...@@ -147,7 +151,20 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) { ...@@ -147,7 +151,20 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Didn't expect an error, but got: %v", err) t.Fatalf("Didn't expect an error, but got: %v", err)
} }
addr := ctx.keysToSiteConfigs["localhost"].Addr localhostKey := "localhost"
item, ok := ctx.keysToSiteConfigs[localhostKey]
if !ok {
availableKeys := make(sort.StringSlice, len(ctx.keysToSiteConfigs))
i := 0
for key := range ctx.keysToSiteConfigs {
availableKeys[i] = fmt.Sprintf("'%s'", key)
i++
}
availableKeys.Sort()
t.Errorf("`%s` not found within registered keys, only these are available: %s", localhostKey, strings.Join(availableKeys, ", "))
return
}
addr := item.Addr
if addr.Port != Port { if addr.Port != Port {
t.Errorf("Expected the port on the address to be set, but got: %#v", addr) t.Errorf("Expected the port on the address to be set, but got: %#v", addr)
} }
...@@ -184,6 +201,64 @@ func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) { ...@@ -184,6 +201,64 @@ func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
} }
} }
func TestKeyNormalization(t *testing.T) {
originalCaseSensitivePath := CaseSensitivePath
defer func() {
CaseSensitivePath = originalCaseSensitivePath
}()
CaseSensitivePath = true
caseSensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/ABCDEF",
},
{
orig: "A/ABCDEF",
res: "a/ABCDEF",
},
{
orig: "A:2015/Port",
res: "a:2015/Port",
},
}
for _, item := range caseSensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to true must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
CaseSensitivePath = false
caseInsensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/abcdef",
},
{
orig: "A/ABCDEF",
res: "a/abcdef",
},
{
orig: "A:2015/Port",
res: "a:2015/port",
},
}
for _, item := range caseInsensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to false must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
}
func TestGetConfig(t *testing.T) { func TestGetConfig(t *testing.T) {
// case insensitivity for key // case insensitivity for key
con := caddy.NewTestController("http", "") con := caddy.NewTestController("http", "")
...@@ -201,6 +276,14 @@ func TestGetConfig(t *testing.T) { ...@@ -201,6 +276,14 @@ func TestGetConfig(t *testing.T) {
if cfg == cfg3 { if cfg == cfg3 {
t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3) t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3)
} }
con.Key = "foo/foobar"
cfg4 := GetConfig(con)
con.Key = "foo/Foobar"
cfg5 := GetConfig(con)
if cfg4 == cfg5 {
t.Errorf("Expected different cases in path to differentiate keys in general")
}
} }
func TestDirectivesList(t *testing.T) { func TestDirectivesList(t *testing.T) {
......
...@@ -29,6 +29,7 @@ import ( ...@@ -29,6 +29,7 @@ import (
"time" "time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls"
) )
// requestReplacer is a strings.Replacer which is used to // requestReplacer is a strings.Replacer which is used to
...@@ -140,6 +141,14 @@ func canLogRequest(r *http.Request) bool { ...@@ -140,6 +141,14 @@ func canLogRequest(r *http.Request) bool {
return false return false
} }
// unescapeBraces finds escaped braces in s and returns
// a string with those braces unescaped.
func unescapeBraces(s string) string {
s = strings.Replace(s, "\\{", "{", -1)
s = strings.Replace(s, "\\}", "}", -1)
return s
}
// Replace performs a replacement of values on s and returns // Replace performs a replacement of values on s and returns
// the string with the replaced values. // the string with the replaced values.
func (r *replacer) Replace(s string) string { func (r *replacer) Replace(s string) string {
...@@ -149,32 +158,59 @@ func (r *replacer) Replace(s string) string { ...@@ -149,32 +158,59 @@ func (r *replacer) Replace(s string) string {
} }
result := "" result := ""
Placeholders: // process each placeholder in sequence
for { for {
idxStart := strings.Index(s, "{") var idxStart, idxEnd int
idxOffset := 0
for { // find first unescaped opening brace
searchSpace := s[idxOffset:]
idxStart = strings.Index(searchSpace, "{")
if idxStart == -1 { if idxStart == -1 {
// no placeholder anymore // no more placeholders
break Placeholders
}
if idxStart == 0 || searchSpace[idxStart-1] != '\\' {
// preceding character is not an escape
idxStart += idxOffset
break break
} }
idxEnd := strings.Index(s[idxStart:], "}") // the brace we found was escaped
// search the rest of the string next
idxOffset += idxStart + 1
}
idxOffset = 0
for { // find first unescaped closing brace
searchSpace := s[idxStart+idxOffset:]
idxEnd = strings.Index(searchSpace, "}")
if idxEnd == -1 { if idxEnd == -1 {
// unpaired placeholder // unpaired placeholder
break Placeholders
}
if idxEnd == 0 || searchSpace[idxEnd-1] != '\\' {
// preceding character is not an escape
idxEnd += idxOffset + idxStart
break break
} }
idxEnd += idxStart // the brace we found was escaped
// search the rest of the string next
idxOffset += idxEnd + 1
}
// get a replacement // get a replacement for the unescaped placeholder
placeholder := s[idxStart : idxEnd+1] placeholder := unescapeBraces(s[idxStart : idxEnd+1])
replacement := r.getSubstitution(placeholder) replacement := r.getSubstitution(placeholder)
// append prefix + replacement // append unescaped prefix + replacement
result += s[:idxStart] + replacement result += strings.TrimPrefix(unescapeBraces(s[:idxStart]), "\\") + replacement
// strip out scanned parts // strip out scanned parts
s = s[idxEnd+1:] s = s[idxEnd+1:]
} }
// append unscanned parts // append unscanned parts
return result + s return result + unescapeBraces(s)
} }
func roundDuration(d time.Duration) time.Duration { func roundDuration(d time.Duration) time.Duration {
...@@ -224,6 +260,16 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -224,6 +260,16 @@ func (r *replacer) getSubstitution(key string) string {
} }
} }
} }
// search response headers then
if r.responseRecorder != nil && key[1] == '<' {
want := key[2 : len(key)-1]
for key, values := range r.responseRecorder.Header() {
// Header placeholders (case-insensitive)
if strings.EqualFold(key, want) {
return strings.Join(values, ",")
}
}
}
// next check for cookies // next check for cookies
if key[1] == '~' { if key[1] == '~' {
name := key[2 : len(key)-1] name := key[2 : len(key)-1]
...@@ -365,12 +411,46 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -365,12 +411,46 @@ func (r *replacer) getSubstitution(key string) string {
} }
elapsedDuration := time.Since(r.responseRecorder.start) elapsedDuration := time.Since(r.responseRecorder.start)
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10) return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
case "{tls_protocol}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedProtocols {
if v == r.request.TLS.Version {
return k
}
}
return "tls" // this should never happen, but guard in case
}
return r.emptyValue // because not using a secure channel
case "{tls_cipher}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedCiphersMap {
if v == r.request.TLS.CipherSuite {
return k
}
}
return "UNKNOWN" // this should never happen, but guard in case
}
return r.emptyValue
default:
// {labelN}
if strings.HasPrefix(key, "{label") {
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
n, err := strconv.Atoi(nStr)
if err != nil || n < 1 {
return r.emptyValue
}
labels := strings.Split(r.request.Host, ".")
if n > len(labels) {
return r.emptyValue
}
return labels[n-1]
}
} }
return r.emptyValue return r.emptyValue
} }
//convertToMilliseconds returns the number of milliseconds in the given duration // convertToMilliseconds returns the number of milliseconds in the given duration
func convertToMilliseconds(d time.Duration) int64 { func convertToMilliseconds(d time.Duration) int64 {
return d.Nanoseconds() / 1e6 return d.Nanoseconds() / 1e6
} }
......
...@@ -53,7 +53,7 @@ func TestReplace(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestReplace(t *testing.T) {
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`) reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader) request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
...@@ -67,6 +67,9 @@ func TestReplace(t *testing.T) { ...@@ -67,6 +67,9 @@ func TestReplace(t *testing.T) {
request.Header.Set("CustomAdd", "caddy") request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious") request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
t.Fatalf("Failed to determine hostname: %v", err) t.Fatalf("Failed to determine hostname: %v", err)
...@@ -84,7 +87,7 @@ func TestReplace(t *testing.T) { ...@@ -84,7 +87,7 @@ func TestReplace(t *testing.T) {
expect string expect string
}{ }{
{"This hostname is {hostname}", "This hostname is " + hostname}, {"This hostname is {hostname}", "This hostname is " + hostname},
{"This host is {host}.", "This host is localhost."}, {"This host is {host}.", "This host is localhost.local."},
{"This request method is {method}.", "This request method is POST."}, {"This request method is {method}.", "This request method is POST."},
{"The response status is {status}.", "The response status is 200."}, {"The response status is {status}.", "The response status is 200."},
{"{when}", "02/Jan/2006:15:04:05 +0000"}, {"{when}", "02/Jan/2006:15:04:05 +0000"},
...@@ -92,10 +95,13 @@ func TestReplace(t *testing.T) { ...@@ -92,10 +95,13 @@ func TestReplace(t *testing.T) {
{"{when_unix}", "1136214252"}, {"{when_unix}", "1136214252"},
{"The Custom header is {>Custom}.", "The Custom header is foobarbaz."}, {"The Custom header is {>Custom}.", "The Custom header is foobarbaz."},
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."}, {"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost\\r\\n" + {"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" + "Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
"Shorterval: 1\\r\\n\\r\\n."}, "Shorterval: 1\\r\\n\\r\\n."},
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."}, {"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
{"The cUsToM response header is {<CuSTom}.", "The cUsToM response header is CustomResponseHeader."},
{"The Non-Existent header is {>Non-Existent}.", "The Non-Existent header is -."}, {"The Non-Existent header is {>Non-Existent}.", "The Non-Existent header is -."},
{"Bad {host placeholder...", "Bad {host placeholder..."}, {"Bad {host placeholder...", "Bad {host placeholder..."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"}, {"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
...@@ -106,6 +112,9 @@ func TestReplace(t *testing.T) { ...@@ -106,6 +112,9 @@ func TestReplace(t *testing.T) {
{"Query string is {query}", "Query string is foo=bar"}, {"Query string is {query}", "Query string is foo=bar"},
{"Query string value for foo is {?foo}", "Query string value for foo is bar"}, {"Query string value for foo is {?foo}", "Query string value for foo is bar"},
{"Missing query string argument is {?missing}", "Missing query string argument is "}, {"Missing query string argument is {?missing}", "Missing query string argument is "},
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
} }
for _, c := range testCases { for _, c := range testCases {
...@@ -138,6 +147,107 @@ func TestReplace(t *testing.T) { ...@@ -138,6 +147,107 @@ func TestReplace(t *testing.T) {
} }
} }
func BenchmarkReplace(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("This hostname is {hostname}")
}
}
func BenchmarkReplaceEscaped(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("\\{ 'hostname': '{hostname}' \\}")
}
}
func TestResponseRecorderNil(t *testing.T) {
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
request.Header.Set("Custom", "foobarbaz")
repl := NewReplacer(request, nil, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
old := now
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
defer func() {
now = old
}()
testCases := []struct {
template string
expect string
}{
{"The Custom response header is {<Custom}.", "The Custom response header is -."},
}
for _, c := range testCases {
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
}
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
......
...@@ -320,6 +320,9 @@ func (s *Server) Serve(ln net.Listener) error { ...@@ -320,6 +320,9 @@ func (s *Server) Serve(ln net.Listener) error {
} }
err := s.Server.Serve(ln) err := s.Server.Serve(ln)
if err == http.ErrServerClosed {
err = nil // not an error worth reporting since closing a server is intentional
}
if s.quicServer != nil { if s.quicServer != nil {
s.quicServer.Close() s.quicServer.Close()
} }
...@@ -421,19 +424,39 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -421,19 +424,39 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
r.URL = trimPathPrefix(r.URL, pathPrefix) r.URL = trimPathPrefix(r.URL, pathPrefix)
} }
// enforce strict host matching, which ensures that the SNI
// value (if any), matches the Host header; essential for
// sites that rely on TLS ClientAuth sharing a port with
// sites that do not - if mismatched, close the connection
if vhost.StrictHostMatching && r.TLS != nil &&
strings.ToLower(r.TLS.ServerName) != strings.ToLower(hostname) {
r.Close = true
log.Printf("[ERROR] %s - strict host matching: SNI (%s) and HTTP Host (%s) values differ",
vhost.Addr, r.TLS.ServerName, hostname)
return http.StatusForbidden, nil
}
return vhost.middlewareChain.ServeHTTP(w, r) return vhost.middlewareChain.ServeHTTP(w, r)
} }
func trimPathPrefix(u *url.URL, prefix string) *url.URL { func trimPathPrefix(u *url.URL, prefix string) *url.URL {
// We need to use URL.EscapedPath() when trimming the pathPrefix as // We need to use URL.EscapedPath() when trimming the pathPrefix as
// URL.Path is ambiguous about / or %2f - see docs. See #1927 // URL.Path is ambiguous about / or %2f - see docs. See #1927
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix) trimmedPath := strings.TrimPrefix(u.EscapedPath(), prefix)
if !strings.HasPrefix(trimmed, "/") { if !strings.HasPrefix(trimmedPath, "/") {
trimmed = "/" + trimmed trimmedPath = "/" + trimmedPath
}
// After trimming path reconstruct uri string with Query before parsing
trimmedURI := trimmedPath
if u.RawQuery != "" || u.ForceQuery == true {
trimmedURI = trimmedPath + "?" + u.RawQuery
}
if u.Fragment != "" {
trimmedURI = trimmedURI + "#" + u.Fragment
} }
trimmedURL, err := url.Parse(trimmed) trimmedURL, err := url.Parse(trimmedURI)
if err != nil { if err != nil {
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err) log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmedURI, err)
return u return u
} }
return trimmedURL return trimmedURL
......
...@@ -129,88 +129,108 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) { ...@@ -129,88 +129,108 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
func TestTrimPathPrefix(t *testing.T) { func TestTrimPathPrefix(t *testing.T) {
for i, pt := range []struct { for i, pt := range []struct {
path string url string
prefix string prefix string
expected string expected string
shouldFail bool shouldFail bool
}{ }{
{ {
path: "/my/path", url: "/my/path",
prefix: "/my", prefix: "/my",
expected: "/path", expected: "/path",
shouldFail: false, shouldFail: false,
}, },
{ {
path: "/my/%2f/path", url: "/my/%2f/path",
prefix: "/my", prefix: "/my",
expected: "/%2f/path", expected: "/%2f/path",
shouldFail: false, shouldFail: false,
}, },
{ {
path: "/my/path", url: "/my/path",
prefix: "/my/", prefix: "/my/",
expected: "/path", expected: "/path",
shouldFail: false, shouldFail: false,
}, },
{ {
path: "/my///path", url: "/my///path",
prefix: "/my", prefix: "/my",
expected: "/path", expected: "/path",
shouldFail: true, shouldFail: true,
}, },
{ {
path: "/my///path", url: "/my///path",
prefix: "/my", prefix: "/my",
expected: "///path", expected: "///path",
shouldFail: false, shouldFail: false,
}, },
{ {
path: "/my/path///slash", url: "/my/path///slash",
prefix: "/my", prefix: "/my",
expected: "/path///slash", expected: "/path///slash",
shouldFail: false, shouldFail: false,
}, },
{ {
path: "/my/%2f/path/%2f", url: "/my/%2f/path/%2f",
prefix: "/my", prefix: "/my",
expected: "/%2f/path/%2f", expected: "/%2f/path/%2f",
shouldFail: false, shouldFail: false,
}, { }, {
path: "/my/%20/path", url: "/my/%20/path",
prefix: "/my", prefix: "/my",
expected: "/%20/path", expected: "/%20/path",
shouldFail: false, shouldFail: false,
}, { }, {
path: "/path", url: "/path",
prefix: "", prefix: "",
expected: "/path", expected: "/path",
shouldFail: false, shouldFail: false,
}, { }, {
path: "/path/my/", url: "/path/my/",
prefix: "/my", prefix: "/my",
expected: "/path/my/", expected: "/path/my/",
shouldFail: false, shouldFail: false,
}, { }, {
path: "", url: "",
prefix: "/my", prefix: "/my",
expected: "/", expected: "/",
shouldFail: false, shouldFail: false,
}, { }, {
path: "/apath", url: "/apath",
prefix: "", prefix: "",
expected: "/apath", expected: "/apath",
shouldFail: false, shouldFail: false,
}, {
url: "/my/path/page.php?akey=value",
prefix: "/my",
expected: "/path/page.php?akey=value",
shouldFail: false,
}, {
url: "/my/path/page?key=value#fragment",
prefix: "/my",
expected: "/path/page?key=value#fragment",
shouldFail: false,
}, {
url: "/my/path/page#fragment",
prefix: "/my",
expected: "/path/page#fragment",
shouldFail: false,
}, {
url: "/my/apath?",
prefix: "/my",
expected: "/apath?",
shouldFail: false,
}, },
} { } {
u, _ := url.Parse(pt.path) u, _ := url.Parse(pt.url)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want { if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.String() != want {
if !pt.shouldFail { if !pt.shouldFail {
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath()) t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.String())
} }
} else if pt.shouldFail { } else if pt.shouldFail {
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath()) t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.String())
} }
} }
} }
......
...@@ -36,6 +36,16 @@ type SiteConfig struct { ...@@ -36,6 +36,16 @@ type SiteConfig struct {
// TLS configuration // TLS configuration
TLS *caddytls.Config TLS *caddytls.Config
// If true, the Host header in the HTTP request must
// match the SNI value in the TLS handshake (if any).
// This should be enabled whenever a site relies on
// TLS client authentication, for example; or any time
// you want to enforce that THIS site's TLS config
// is used and not the TLS config of any other site
// on the same listener. TODO: Check how relevant this
// is with TLS 1.3.
StrictHostMatching bool
// Uncompiled middleware stack // Uncompiled middleware stack
middleware []Middleware middleware []Middleware
......
...@@ -277,7 +277,7 @@ func TestHostname(t *testing.T) { ...@@ -277,7 +277,7 @@ func TestHostname(t *testing.T) {
// // Test 3 - ipv6 without port and brackets // // Test 3 - ipv6 without port and brackets
// {"2001:4860:4860::8888", "google-public-dns-a.google.com."}, // {"2001:4860:4860::8888", "google-public-dns-a.google.com."},
// Test 4 - no hostname available // Test 4 - no hostname available
{"1.1.1.1", "1.1.1.1"}, {"0.0.0.0", "0.0.0.0"},
} }
for i, test := range tests { for i, test := range tests {
......
...@@ -67,6 +67,10 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -67,6 +67,10 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Write log entries // Write log entries
for _, e := range rule.Entries { for _, e := range rule.Entries {
// Check if there is an exception to prevent log being written
if !e.Log.ShouldLog(r.URL.Path) {
continue
}
// Mask IP Address // Mask IP Address
if e.Log.IPMaskExists { if e.Log.IPMaskExists {
...@@ -78,6 +82,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -78,6 +82,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
} }
e.Log.Println(rep.Replace(e.Format)) e.Log.Println(rep.Replace(e.Format))
} }
return status, err return status, err
......
...@@ -177,3 +177,85 @@ func TestMultiEntries(t *testing.T) { ...@@ -177,3 +177,85 @@ func TestMultiEntries(t *testing.T) {
t.Errorf("Expected %q, but got %q", expect, got) t.Errorf("Expected %q, but got %q", expect, got)
} }
} }
func TestLogExcept(t *testing.T) {
tests := []struct {
LogRules []Rule
logPath string
shouldLog bool
}{
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/soup"},
},
Format: DefaultLogFormat,
}},
}}, `/soup`, false},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/tart"},
},
Format: DefaultLogFormat,
}},
}}, `/soup`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/soup"},
},
Format: DefaultLogFormat,
}},
}}, `/tomatosoup`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie/"},
},
Format: DefaultLogFormat,
}},
// Check exception with a trailing slash does not match without
}}, `/pie`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie.php"},
},
Format: DefaultLogFormat,
}},
}}, `/pie`, true},
{[]Rule{{
PathScope: "/",
Entries: []*Entry{{
Log: &httpserver.Logger{
Exceptions: []string{"/pie"},
},
Format: DefaultLogFormat,
}},
// Check that a word without trailing slash will match a filename
}}, `/pie.php`, false},
}
for i, test := range tests {
for _, LogRule := range test.LogRules {
for _, e := range LogRule.Entries {
shouldLog := e.Log.ShouldLog(test.logPath)
if shouldLog != test.shouldLog {
t.Fatalf("Test %d expected shouldLog=%t but got shouldLog=%t,", i, test.shouldLog, shouldLog)
}
}
}
}
}
...@@ -44,7 +44,7 @@ func setup(c *caddy.Controller) error { ...@@ -44,7 +44,7 @@ func setup(c *caddy.Controller) error {
func logParse(c *caddy.Controller) ([]*Rule, error) { func logParse(c *caddy.Controller) ([]*Rule, error) {
var rules []*Rule var rules []*Rule
var logExceptions []string
for c.Next() { for c.Next() {
args := c.RemainingArgs() args := c.RemainingArgs()
...@@ -91,6 +91,12 @@ func logParse(c *caddy.Controller) ([]*Rule, error) { ...@@ -91,6 +91,12 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
} }
} else if what == "except" {
for i := 0; i < len(where); i++ {
logExceptions = append(logExceptions, where[i])
}
} else if httpserver.IsLogRollerSubdirective(what) { } else if httpserver.IsLogRollerSubdirective(what) {
if err := httpserver.ParseRoller(logRoller, what, where...); err != nil { if err := httpserver.ParseRoller(logRoller, what, where...); err != nil {
...@@ -133,6 +139,7 @@ func logParse(c *caddy.Controller) ([]*Rule, error) { ...@@ -133,6 +139,7 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
V4ipMask: ip4Mask, V4ipMask: ip4Mask,
V6ipMask: ip6Mask, V6ipMask: ip6Mask,
IPMaskExists: ipMaskExists, IPMaskExists: ipMaskExists,
Exceptions: logExceptions,
}, },
Format: format, Format: format,
}) })
......
...@@ -58,6 +58,10 @@ type Upstream interface { ...@@ -58,6 +58,10 @@ type Upstream interface {
// Gets the number of upstream hosts. // Gets the number of upstream hosts.
GetHostCount() int GetHostCount() int
// Gets how long to wait before timing out
// the request
GetTimeout() time.Duration
// Stops the upstream from proxying requests to shutdown goroutines cleanly. // Stops the upstream from proxying requests to shutdown goroutines cleanly.
Stop() error Stop() error
} }
...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -187,7 +191,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if nameURL, err := url.Parse(host.Name); err == nil { if nameURL, err := url.Parse(host.Name); err == nil {
outreq.Host = nameURL.Host outreq.Host = nameURL.Host
if proxy == nil { if proxy == nil {
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost) proxy = NewSingleHostReverseProxy(nameURL,
host.WithoutPathPrefix,
http.DefaultMaxIdleConnsPerHost,
upstream.GetTimeout(),
)
} }
// use upstream credentials by default // use upstream credentials by default
......
This diff is collapsed.
...@@ -94,6 +94,10 @@ type ReverseProxy struct { ...@@ -94,6 +94,10 @@ type ReverseProxy struct {
// If zero, no periodic flushing is done. // If zero, no periodic flushing is done.
FlushInterval time.Duration FlushInterval time.Duration
// dialer is used when values from the
// defaultDialer need to be overridden per Proxy
dialer *net.Dialer
srvResolver srvResolver srvResolver srvResolver
} }
...@@ -103,13 +107,13 @@ type ReverseProxy struct { ...@@ -103,13 +107,13 @@ type ReverseProxy struct {
// What we need is just the path, so if "unix:/var/run/www.socket" // What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be // was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming. // "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) { return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):]) return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
} }
} }
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
service := locator service := locator
if strings.HasPrefix(locator, "srv://") { if strings.HasPrefix(locator, "srv://") {
service = locator[6:] service = locator[6:]
...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) ...@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
} }
} }
...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string { ...@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
// the target request will be for /base/dir. // the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages", // Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages. // without is "/api", then the target request will be for /messages.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout time.Duration) *ReverseProxy {
targetQuery := target.RawQuery targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
if target.Scheme == "unix" { if target.Scheme == "unix" {
...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -226,15 +230,21 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
} }
} }
dialer := *defaultDialer
if timeout != defaultDialer.Timeout {
dialer.Timeout = timeout
}
rp := &ReverseProxy{ rp := &ReverseProxy{
Director: director, Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver, srvResolver: net.DefaultResolver,
dialer: &dialer,
} }
if target.Scheme == "unix" { if target.Scheme == "unix" {
rp.Transport = &http.Transport{ rp.Transport = &http.Transport{
Dial: socketDial(target.String()), Dial: socketDial(target.String(), timeout),
} }
} else if target.Scheme == "quic" { } else if target.Scheme == "quic" {
rp.Transport = &h2quic.RoundTripper{ rp.Transport = &h2quic.RoundTripper{
...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -244,9 +254,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}, },
} }
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial dialFunc := rp.dialer.Dial
if strings.HasPrefix(target.Scheme, "srv") { if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String()) dialFunc = rp.srvDialerFunc(target.String(), timeout)
} }
transport := &http.Transport{ transport := &http.Transport{
...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() { ...@@ -275,7 +285,7 @@ func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil { if rp.Transport == nil {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial, Dial: rp.dialer.Dial,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -306,7 +316,9 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
if requestIsWebsocket(outreq) { if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport) transport = newConnHijackerTransport(transport)
} else if transport == nil { } else if transport == nil {
transport = http.DefaultTransport transport = &http.Transport{
Dial: rp.dialer.Dial,
}
} }
rp.Director(outreq) rp.Director(outreq)
...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -361,7 +373,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
} }
bufferPool.Put(hj.Replay) bufferPool.Put(hj.Replay)
} else { } else {
backendConn, err = net.Dial("tcp", outreq.URL.Host) backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"testing" "testing"
"time"
) )
const ( const (
...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) { ...@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
} }
port := uint16(pp) port := uint16(pp)
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second)
rp.srvResolver = testResolver{ rp.srvResolver = testResolver{
result: []*net.SRV{ result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
......
...@@ -49,6 +49,7 @@ type staticUpstream struct { ...@@ -49,6 +49,7 @@ type staticUpstream struct {
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
KeepAlive int KeepAlive int
Timeout time.Duration
FailTimeout time.Duration FailTimeout time.Duration
TryDuration time.Duration TryDuration time.Duration
TryInterval time.Duration TryInterval time.Duration
...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) ...@@ -92,6 +93,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
TryInterval: 250 * time.Millisecond, TryInterval: 250 * time.Millisecond,
MaxConns: 0, MaxConns: 0,
KeepAlive: http.DefaultMaxIdleConnsPerHost, KeepAlive: http.DefaultMaxIdleConnsPerHost,
Timeout: 30 * time.Second,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
} }
...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { ...@@ -225,7 +227,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return nil, err return nil, err
} }
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive) uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout)
if u.insecureSkipVerify { if u.insecureSkipVerify {
uh.ReverseProxy.UseInsecureTransport() uh.ReverseProxy.UseInsecureTransport()
} }
...@@ -431,9 +433,10 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { ...@@ -431,9 +433,10 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
} }
u.downstreamHeaders.Add(header, value) u.downstreamHeaders.Add(header, value)
case "transparent": case "transparent":
// Note: X-Forwarded-For header is always being appended for proxy connections
// See implementation of createUpstreamRequest in proxy.go
u.upstreamHeaders.Add("Host", "{host}") u.upstreamHeaders.Add("Host", "{host}")
u.upstreamHeaders.Add("X-Real-IP", "{remote}") u.upstreamHeaders.Add("X-Real-IP", "{remote}")
u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}") u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
case "websocket": case "websocket":
u.upstreamHeaders.Add("Connection", "{>Connection}") u.upstreamHeaders.Add("Connection", "{>Connection}")
...@@ -463,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { ...@@ -463,6 +466,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
return c.ArgErr() return c.ArgErr()
} }
u.KeepAlive = n u.KeepAlive = n
case "timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return c.Errf("unable to parse timeout duration '%s'", c.Val())
}
u.Timeout = dur
default: default:
return c.Errf("unknown property '%s'", c.Val()) return c.Errf("unknown property '%s'", c.Val())
} }
...@@ -618,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration { ...@@ -618,6 +630,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval return u.TryInterval
} }
// GetTimeout returns u.Timeout.
func (u *staticUpstream) GetTimeout() time.Duration {
return u.Timeout
}
func (u *staticUpstream) GetHostCount() int { func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts) return len(u.Hosts)
} }
......
...@@ -282,7 +282,8 @@ func TestStop(t *testing.T) { ...@@ -282,7 +282,8 @@ func TestStop(t *testing.T) {
} }
} }
func TestParseBlock(t *testing.T) { func TestParseBlockTransparent(t *testing.T) {
// tests for transparent proxy presets
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
tests := []struct { tests := []struct {
config string config string
...@@ -316,6 +317,10 @@ func TestParseBlock(t *testing.T) { ...@@ -316,6 +317,10 @@ func TestParseBlock(t *testing.T) {
if _, ok := headers["X-Forwarded-Proto"]; !ok { if _, ok := headers["X-Forwarded-Proto"]; !ok {
t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1) t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1)
} }
if _, ok := headers["X-Forwarded-For"]; ok {
t.Errorf("Test %d: Found unexpected X-Forwarded-For header", i+1)
}
} }
} }
} }
......
...@@ -63,22 +63,38 @@ type Rule interface { ...@@ -63,22 +63,38 @@ type Rule interface {
// SimpleRule is a simple rewrite rule. // SimpleRule is a simple rewrite rule.
type SimpleRule struct { type SimpleRule struct {
From, To string Regexp *regexp.Regexp
To string
Negate bool
} }
// NewSimpleRule creates a new Simple Rule // NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule { func NewSimpleRule(from, to string, negate bool) (*SimpleRule, error) {
return SimpleRule{from, to} r, err := regexp.Compile(from)
if err != nil {
return nil, err
}
return &SimpleRule{
Regexp: r,
To: to,
Negate: negate,
}, nil
} }
// BasePath satisfies httpserver.Config // BasePath satisfies httpserver.Config
func (s SimpleRule) BasePath() string { return s.From } func (s SimpleRule) BasePath() string { return "/" }
// Match satisfies httpserver.Config // Match satisfies httpserver.Config
func (s SimpleRule) Match(r *http.Request) bool { return s.From == r.URL.Path } func (s *SimpleRule) Match(r *http.Request) bool {
matches := regexpMatches(s.Regexp, "/", r.URL.Path)
if s.Negate {
return len(matches) == 0
}
return len(matches) > 0
}
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result { func (s *SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
// attempt rewrite // attempt rewrite
return To(fs, r, s.To, newReplacer(r)) return To(fs, r, s.To, newReplacer(r))
...@@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool { ...@@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool {
return true return true
} }
// otherwise validate regex // otherwise validate regex
return r.regexpMatches(req.URL.Path) != nil return regexpMatches(r.Regexp, r.Base, req.URL.Path) != nil
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
...@@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) ...@@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result)
// validate regexp if present // validate regexp if present
if r.Regexp != nil { if r.Regexp != nil {
matches := r.regexpMatches(req.URL.Path) matches := regexpMatches(r.Regexp, r.Base, req.URL.Path)
switch len(matches) { switch len(matches) {
case 0: case 0:
// no match // no match
...@@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool { ...@@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool {
return !mustUse return !mustUse
} }
func (r ComplexRule) regexpMatches(rPath string) []string { func regexpMatches(regexp *regexp.Regexp, base, rPath string) []string {
if r.Regexp != nil { if regexp != nil {
// include trailing slash in regexp if present // include trailing slash in regexp if present
start := len(r.Base) start := len(base)
if strings.HasSuffix(r.Base, "/") { if strings.HasSuffix(base, "/") {
start-- start--
} }
return r.Regexp.FindStringSubmatch(rPath[start:]) return regexp.FindStringSubmatch(rPath[start:])
} }
return nil return nil
} }
......
...@@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) { ...@@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) {
rw := Rewrite{ rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter), Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{ Rules: []httpserver.HandlerConfig{
NewSimpleRule("/from", "/to"), newSimpleRule(t, "^/from$", "/to"),
NewSimpleRule("/a", "/b"), newSimpleRule(t, "^/a$", "/b"),
NewSimpleRule("/b", "/b{uri}"), newSimpleRule(t, "^/b$", "/b{uri}"),
}, },
FileSys: http.Dir("."), FileSys: http.Dir("."),
} }
...@@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) { ...@@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) {
} }
} }
// TestWordpress is a test for wordpress usecase.
func TestWordpress(t *testing.T) {
rw := Rewrite{
Next: httpserver.HandlerFunc(urlPrinter),
Rules: []httpserver.HandlerConfig{
// both rules are same, thanks to Go regexp (confusion).
newSimpleRule(t, "^/wp-admin", "{path} {path}/ /index.php?{query}", true),
newSimpleRule(t, "^\\/wp-admin", "{path} {path}/ /index.php?{query}", true),
},
FileSys: http.Dir("."),
}
tests := []struct {
from string
expectedTo string
}{
{"/wp-admin", "/wp-admin"},
{"/wp-admin/login.php", "/wp-admin/login.php"},
{"/not-wp-admin/login.php?not=admin", "/index.php?not=admin"},
{"/loophole", "/index.php"},
{"/user?name=john", "/index.php?name=john"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
ctx := context.WithValue(req.Context(), httpserver.OriginalURLCtxKey, *req.URL)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
rw.ServeHTTP(rec, req)
if got, want := rec.Body.String(), test.expectedTo; got != want {
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", i, want, got)
}
}
}
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprint(w, r.URL.String()) fmt.Fprint(w, r.URL.String())
return 0, nil return 0, nil
......
...@@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) { ...@@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
var base = "/" var base = "/"
var pattern, to string var pattern, to string
var ext []string var ext []string
var negate bool
args := c.RemainingArgs() args := c.RemainingArgs()
...@@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) { ...@@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
// the only unhandled case is 2 and above // the only unhandled case is 2 and above
default: default:
rule = NewSimpleRule(args[0], strings.Join(args[1:], " ")) if args[0] == "not" {
negate = true
args = args[1:]
}
rule, err = NewSimpleRule(args[0], strings.Join(args[1:], " "), negate)
if err != nil {
return nil, err
}
rules = append(rules, rule) rules = append(rules, rule)
} }
......
...@@ -50,6 +50,19 @@ func TestSetup(t *testing.T) { ...@@ -50,6 +50,19 @@ func TestSetup(t *testing.T) {
} }
} }
// newSimpleRule is convenience test function for SimpleRule.
func newSimpleRule(t *testing.T, from, to string, negate ...bool) Rule {
var n bool
if len(negate) > 0 {
n = negate[0]
}
rule, err := NewSimpleRule(from, to, n)
if err != nil {
t.Fatal(err)
}
return rule
}
func TestRewriteParse(t *testing.T) { func TestRewriteParse(t *testing.T) {
simpleTests := []struct { simpleTests := []struct {
input string input string
...@@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) { ...@@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) {
expected []Rule expected []Rule
}{ }{
{`rewrite /from /to`, false, []Rule{ {`rewrite /from /to`, false, []Rule{
SimpleRule{From: "/from", To: "/to"}, newSimpleRule(t, "/from", "/to"),
}}, }},
{`rewrite /from /to {`rewrite /from /to
rewrite a b`, false, []Rule{ rewrite a b`, false, []Rule{
SimpleRule{From: "/from", To: "/to"}, newSimpleRule(t, "/from", "/to"),
SimpleRule{From: "a", To: "b"}, newSimpleRule(t, "a", "b"),
}}, }},
{`rewrite a`, true, []Rule{}}, {`rewrite a`, true, []Rule{}},
{`rewrite`, true, []Rule{}}, {`rewrite`, true, []Rule{}},
{`rewrite a b c`, false, []Rule{ {`rewrite a b c`, false, []Rule{
SimpleRule{From: "a", To: "b c"}, newSimpleRule(t, "a", "b c"),
}},
{`rewrite not a b c`, false, []Rule{
newSimpleRule(t, "a", "b c", true),
}}, }},
} }
...@@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) { ...@@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) {
} }
for j, e := range test.expected { for j, e := range test.expected {
actualRule := actual[j].(SimpleRule) actualRule := actual[j].(*SimpleRule)
expectedRule := e.(SimpleRule) expectedRule := e.(*SimpleRule)
if actualRule.From != expectedRule.From { if actualRule.Regexp.String() != expectedRule.Regexp.String() {
t.Errorf("Test %d, rule %d: Expected From=%s, got %s", t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
i, j, expectedRule.From, actualRule.From) i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
} }
if actualRule.To != expectedRule.To { if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s", t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To) i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
}
if actualRule.Negate != expectedRule.Negate {
t.Errorf("Test %d, rule %d: Expected Negate=%v, got %v",
i, j, expectedRule.Negate, actualRule.Negate)
} }
} }
} }
......
...@@ -265,21 +265,21 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error { ...@@ -265,21 +265,21 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
return err return err
} }
if leaf.Subject.CommonName != "" { if leaf.Subject.CommonName != "" { // TODO: CommonName is deprecated
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)} cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
} }
for _, name := range leaf.DNSNames { for _, name := range leaf.DNSNames {
if name != leaf.Subject.CommonName { if name != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(name)) cert.Names = append(cert.Names, strings.ToLower(name))
} }
} }
for _, ip := range leaf.IPAddresses { for _, ip := range leaf.IPAddresses {
if ipStr := ip.String(); ipStr != leaf.Subject.CommonName { if ipStr := ip.String(); ipStr != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(ipStr)) cert.Names = append(cert.Names, strings.ToLower(ipStr))
} }
} }
for _, email := range leaf.EmailAddresses { for _, email := range leaf.EmailAddresses {
if email != leaf.Subject.CommonName { if email != leaf.Subject.CommonName { // TODO: CommonName is deprecated
cert.Names = append(cert.Names, strings.ToLower(email)) cert.Names = append(cert.Names, strings.ToLower(email))
} }
} }
......
...@@ -43,10 +43,11 @@ func TestUnexportedGetCertificate(t *testing.T) { ...@@ -43,10 +43,11 @@ func TestUnexportedGetCertificate(t *testing.T) {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert) // TODO: Re-implement this behavior when I'm not in the middle of upgrading for ACMEv2 support. :) (it was reverted in #2037)
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted { // // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) // if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
} // t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
// }
// When no certificate matches and SNI is NOT provided, a random is returned // When no certificate matches and SNI is NOT provided, a random is returned
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted { if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
......
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/telemetry" "github.com/mholt/caddy/telemetry"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// acmeMu ensures that only one ACME challenge occurs at a time. // acmeMu ensures that only one ACME challenge occurs at a time.
...@@ -90,26 +90,21 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -90,26 +90,21 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
// If not registered, the user must register an account with the CA // If not registered, the user must register an account with the CA
// and agree to terms // and agree to terms
if leUser.Registration == nil { if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" { termsURL := client.GetToSURL()
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL if !Agreed && termsURL != "" {
Agreed = askUserAgreement(client.GetToSURL())
} }
if !Agreed && reg.TosURL == "" { if !Agreed && termsURL != "" {
return nil, errors.New("user must agree to terms") return nil, errors.New("user must agree to CA terms (use -agree flag)")
} }
} }
err = client.AgreeToTOS() reg, err := client.Register(Agreed)
if err != nil { if err != nil {
saveUser(storage, leUser) // Might as well try, right? return nil, errors.New("registration error: " + err.Error())
return nil, errors.New("error agreeing to terms: " + err.Error())
} }
leUser.Registration = reg
// save user to the file system // save user to the file system
err = saveUser(storage, leUser) err = saveUser(storage, leUser)
...@@ -137,38 +132,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -137,38 +132,57 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
useHTTPPort = DefaultHTTPAlternatePort useHTTPPort = DefaultHTTPAlternatePort
} }
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See which port TLS-SNI challenges will be accomplished on // See which port TLS-SNI challenges will be accomplished on
useTLSSNIPort := TLSSNIChallengePort // useTLSSNIPort := TLSSNIChallengePort
if config.AltTLSSNIPort != "" { // if config.AltTLSSNIPort != "" {
useTLSSNIPort = config.AltTLSSNIPort // useTLSSNIPort = config.AltTLSSNIPort
} // }
// err := c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
// if err != nil {
// return nil, err
// }
// if using file storage, we can distribute the HTTP challenge across
// all instances sharing the acme folder; either way, we must still set
// the address for the default HTTP provider server
var useDistributedHTTPSolver bool
if storage, err := c.config.StorageFor(c.config.CAUrl); err == nil {
if _, ok := storage.(*FileStorage); ok {
useDistributedHTTPSolver = true
}
}
if useDistributedHTTPSolver {
c.acmeClient.SetChallengeProvider(acme.HTTP01, distributedHTTPSolver{
// being careful to respect user's listener bind preferences
httpProviderServer: acme.NewHTTPProviderServer(config.ListenHost, useHTTPPort),
})
} else {
// Always respect user's bind preferences by using config.ListenHost. // Always respect user's bind preferences by using config.ListenHost.
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress() // NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
// must be called before SetChallengeProvider(), since they reset the // must be called before SetChallengeProvider() (see above), since they reset
// challenge provider back to the default one! // the challenge provider back to the default one! (still true in March 2018)
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort)) err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
if err != nil {
return nil, err
} }
// TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// See if TLS challenge needs to be handled by our own facilities // See if TLS challenge needs to be handled by our own facilities
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) { // if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache}) // c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
} // }
// Disable any challenges that should not be used // Disable any challenges that should not be used
var disabledChallenges []acme.Challenge var disabledChallenges []acme.Challenge
if DisableHTTPChallenge { if DisableHTTPChallenge {
disabledChallenges = append(disabledChallenges, acme.HTTP01) disabledChallenges = append(disabledChallenges, acme.HTTP01)
} }
if DisableTLSSNIChallenge { // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
disabledChallenges = append(disabledChallenges, acme.TLSSNI01) // if DisableTLSSNIChallenge {
} // disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
// }
if len(disabledChallenges) > 0 { if len(disabledChallenges) > 0 {
c.acmeClient.ExcludeChallenges(disabledChallenges) c.acmeClient.ExcludeChallenges(disabledChallenges)
} }
...@@ -189,7 +203,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) ...@@ -189,7 +203,9 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
} }
// Use the DNS challenge exclusively // Use the DNS challenge exclusively
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01}) // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01})
c.acmeClient.SetChallengeProvider(acme.DNS01, prov) c.acmeClient.SetChallengeProvider(acme.DNS01, prov)
} }
...@@ -222,41 +238,31 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -222,41 +238,31 @@ func (c *ACMEClient) Obtain(name string) error {
} }
}() }()
Attempts:
for attempts := 0; attempts < 2; attempts++ { for attempts := 0; attempts < 2; attempts++ {
namesObtaining.Add([]string{name}) namesObtaining.Add([]string{name})
acmeMu.Lock() acmeMu.Lock()
certificate, failures := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple) certificate, err := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
acmeMu.Unlock() acmeMu.Unlock()
namesObtaining.Remove([]string{name}) namesObtaining.Remove([]string{name})
if len(failures) > 0 { if err != nil {
// Error - try to fix it or report it to the user and abort // for a certain kind of error, we can enumerate the error per-domain
var errMsg string // we'll combine all the failures into a single error message if failures, ok := err.(acme.ObtainError); ok && len(failures) > 0 {
var promptedForAgreement bool // only prompt user for agreement at most once var errMsg string // combine all the failures into a single error message
for errDomain, obtainErr := range failures { for errDomain, obtainErr := range failures {
if obtainErr == nil { if obtainErr == nil {
continue continue
} }
if tosErr, ok := obtainErr.(acme.TOSError); ok { errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr)
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
} }
return errors.New(errMsg)
} }
// If user did not agree or it was any other kind of error, just append to the list of errors return fmt.Errorf("[%s] failed to obtain certificate: %v", name, err)
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
} }
return errors.New(errMsg)
// double-check that we actually got a certificate, in case there's a bug upstream (see issue #2121)
if certificate.Domain == "" || certificate.Certificate == nil {
return errors.New("returned certificate was empty; probably an unchecked error obtaining it")
} }
// Success - immediately save the certificate resource // Success - immediately save the certificate resource
...@@ -315,23 +321,20 @@ func (c *ACMEClient) Renew(name string) error { ...@@ -315,23 +321,20 @@ func (c *ACMEClient) Renew(name string) error {
acmeMu.Unlock() acmeMu.Unlock()
namesObtaining.Remove([]string{name}) namesObtaining.Remove([]string{name})
if err == nil { if err == nil {
// double-check that we actually got a certificate; check a couple fields
// TODO: This is a temporary workaround for what I think is a bug in the acmev2 package (March 2018)
// but it might not hurt to keep this extra check in place
if newCertMeta.Domain == "" || newCertMeta.Certificate == nil {
err = errors.New("returned certificate was empty; probably an unchecked error renewing it")
} else {
success = true success = true
break break
} }
// If the legal terms were updated and need to be
// agreed to again, we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.acmeClient.AgreeToTOS()
if err != nil {
return err
}
continue
} }
// For any other kind of error, wait 10s and try again. // wait a little bit and try again
wait := 10 * time.Second wait := 10 * time.Second
log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait) log.Printf("[ERROR] Renewing [%v]: %v; trying again in %s", name, err, wait)
time.Sleep(wait) time.Sleep(wait)
} }
......
...@@ -25,7 +25,7 @@ import ( ...@@ -25,7 +25,7 @@ import (
"github.com/klauspost/cpuid" "github.com/klauspost/cpuid"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// Config describes how TLS should be configured and used. // Config describes how TLS should be configured and used.
...@@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config { ...@@ -190,10 +190,15 @@ func NewConfig(inst *caddy.Instance) *Config {
// it does not load them into memory. If allowPrompts is true, // it does not load them into memory. If allowPrompts is true,
// the user may be shown a prompt. // the user may be shown a prompt.
func (c *Config) ObtainCert(name string, allowPrompts bool) error { func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if !c.Managed || !HostQualifies(name) { skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil return nil
} }
// we expect this to be a new (non-existent) site
storage, err := c.StorageFor(c.CAUrl) storage, err := c.StorageFor(c.CAUrl)
if err != nil { if err != nil {
return err return err
...@@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error { ...@@ -205,9 +210,6 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
if siteExists { if siteExists {
return nil return nil
} }
if c.ACMEEmail == "" {
c.ACMEEmail = getEmail(storage, allowPrompts)
}
client, err := newACMEClient(c, allowPrompts) client, err := newACMEClient(c, allowPrompts)
if err != nil { if err != nil {
...@@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error { ...@@ -219,6 +221,14 @@ func (c *Config) ObtainCert(name string, allowPrompts bool) error {
// RenewCert renews the certificate for name using c. It stows the // RenewCert renews the certificate for name using c. It stows the
// renewed certificate and its assets in storage if successful. // renewed certificate and its assets in storage if successful.
func (c *Config) RenewCert(name string, allowPrompts bool) error { func (c *Config) RenewCert(name string, allowPrompts bool) error {
skip, err := c.preObtainOrRenewChecks(name, allowPrompts)
if err != nil {
return err
}
if skip {
return nil
}
client, err := newACMEClient(c, allowPrompts) client, err := newACMEClient(c, allowPrompts)
if err != nil { if err != nil {
return err return err
...@@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error { ...@@ -226,6 +236,33 @@ func (c *Config) RenewCert(name string, allowPrompts bool) error {
return client.Renew(name) return client.Renew(name)
} }
// preObtainOrRenewChecks perform a few simple checks before
// obtaining or renewing a certificate with ACME, and returns
// whether this name should be skipped (like if it's not
// managed TLS) as well as any error. It ensures that the
// config is Managed, that the name qualifies for a certificate,
// and that an email address is available.
func (c *Config) preObtainOrRenewChecks(name string, allowPrompts bool) (bool, error) {
if !c.Managed || !HostQualifies(name) {
return true, nil
}
// wildcard certificates require DNS challenge (as of March 2018)
if strings.Contains(name, "*") && c.DNSProvider == "" {
return false, fmt.Errorf("wildcard domain name (%s) requires DNS challenge; use dns subdirective to configure it", name)
}
if c.ACMEEmail == "" {
var err error
c.ACMEEmail, err = getEmail(c, allowPrompts)
if err != nil {
return false, err
}
}
return false, nil
}
// StorageFor obtains a TLS Storage instance for the given CA URL which should // StorageFor obtains a TLS Storage instance for the given CA URL which should
// be unique for every different ACME CA. If a StorageCreator is set on this // be unique for every different ACME CA. If a StorageCreator is set on this
// Config, it will be used. Otherwise the default file storage implementation // Config, it will be used. Otherwise the default file storage implementation
...@@ -476,6 +513,14 @@ func assertConfigsCompatible(cfg1, cfg2 *Config) error { ...@@ -476,6 +513,14 @@ func assertConfigsCompatible(cfg1, cfg2 *Config) error {
if c1.ClientAuth != c2.ClientAuth { if c1.ClientAuth != c2.ClientAuth {
return fmt.Errorf("client authentication policy mismatch") return fmt.Errorf("client authentication policy mismatch")
} }
if c1.ClientAuth != tls.NoClientCert && c2.ClientAuth != tls.NoClientCert && c1.ClientCAs != c2.ClientCAs {
// Two hosts defined on the same listener are not compatible if they
// have ClientAuth enabled, because there's no guarantee beyond the
// hostname which config will be used (because SNI only has server name).
// To prevent clients from bypassing authentication, require that
// ClientAuth be configured in an unambiguous manner.
return fmt.Errorf("multiple hosts requiring client authentication ambiguously configured")
}
return nil return nil
} }
...@@ -511,7 +556,7 @@ func SetDefaultTLSParams(config *Config) { ...@@ -511,7 +556,7 @@ func SetDefaultTLSParams(config *Config) {
// Set default protocol min and max versions - must balance compatibility and security // Set default protocol min and max versions - must balance compatibility and security
if config.ProtocolMinVersion == 0 { if config.ProtocolMinVersion == 0 {
config.ProtocolMinVersion = tls.VersionTLS11 config.ProtocolMinVersion = tls.VersionTLS12
} }
if config.ProtocolMaxVersion == 0 { if config.ProtocolMaxVersion == 0 {
config.ProtocolMaxVersion = tls.VersionTLS12 config.ProtocolMaxVersion = tls.VersionTLS12
...@@ -532,7 +577,8 @@ var supportedKeyTypes = map[string]acme.KeyType{ ...@@ -532,7 +577,8 @@ var supportedKeyTypes = map[string]acme.KeyType{
// Map of supported protocols. // Map of supported protocols.
// HTTP/2 only supports TLS 1.2 and higher. // HTTP/2 only supports TLS 1.2 and higher.
var supportedProtocols = map[string]uint16{ // If updating this map, also update tlsProtocolStringToMap in caddyhttp/fastcgi/fastcgi.go
var SupportedProtocols = map[string]uint16{
"tls1.0": tls.VersionTLS10, "tls1.0": tls.VersionTLS10,
"tls1.1": tls.VersionTLS11, "tls1.1": tls.VersionTLS11,
"tls1.2": tls.VersionTLS12, "tls1.2": tls.VersionTLS12,
...@@ -548,7 +594,7 @@ var supportedProtocols = map[string]uint16{ ...@@ -548,7 +594,7 @@ var supportedProtocols = map[string]uint16{
// it is always added (even though it is not technically a cipher suite). // it is always added (even though it is not technically a cipher suite).
// //
// This map, like any map, is NOT ORDERED. Do not range over this map. // This map, like any map, is NOT ORDERED. Do not range over this map.
var supportedCiphersMap = map[string]uint16{ var SupportedCiphersMap = map[string]uint16{
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
......
...@@ -35,13 +35,14 @@ import ( ...@@ -35,13 +35,14 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes. // loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
...@@ -106,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error { ...@@ -106,7 +107,8 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error {
// TODO: Use Storage interface instead of disk directly // TODO: Use Storage interface instead of disk directly
var ocspFileNamePrefix string var ocspFileNamePrefix string
if len(cert.Names) > 0 { if len(cert.Names) > 0 {
ocspFileNamePrefix = cert.Names[0] + "-" firstName := strings.Replace(cert.Names[0], "*", "wildcard_", -1)
ocspFileNamePrefix = firstName + "-"
} }
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle) ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
ocspCachePath := filepath.Join(ocspFolder, ocspFileName) ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
...@@ -216,10 +218,13 @@ func makeSelfSignedCert(config *Config) error { ...@@ -216,10 +218,13 @@ func makeSelfSignedCert(config *Config) error {
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
} }
var names []string
if ip := net.ParseIP(config.Hostname); ip != nil { if ip := net.ParseIP(config.Hostname); ip != nil {
names = append(names, strings.ToLower(ip.String()))
cert.IPAddresses = append(cert.IPAddresses, ip) cert.IPAddresses = append(cert.IPAddresses, ip)
} else { } else {
cert.DNSNames = append(cert.DNSNames, config.Hostname) names = append(names, strings.ToLower(config.Hostname))
cert.DNSNames = append(cert.DNSNames, strings.ToLower(config.Hostname))
} }
publicKey := func(privKey interface{}) interface{} { publicKey := func(privKey interface{}) interface{} {
...@@ -245,7 +250,7 @@ func makeSelfSignedCert(config *Config) error { ...@@ -245,7 +250,7 @@ func makeSelfSignedCert(config *Config) error {
PrivateKey: privKey, PrivateKey: privKey,
Leaf: cert, Leaf: cert,
}, },
Names: cert.DNSNames, Names: names,
NotAfter: cert.NotAfter, NotAfter: cert.NotAfter,
Hash: hashCertificateChain(chain), Hash: hashCertificateChain(chain),
}) })
......
...@@ -30,14 +30,14 @@ func init() { ...@@ -30,14 +30,14 @@ func init() {
RegisterStorageProvider("file", NewFileStorage) RegisterStorageProvider("file", NewFileStorage)
} }
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
// NewFileStorage is a StorageConstructor function that creates a new // NewFileStorage is a StorageConstructor function that creates a new
// Storage instance backed by the local disk. The resulting Storage // Storage instance backed by the local disk. The resulting Storage
// instance is guaranteed to be non-nil if there is no error. // instance is guaranteed to be non-nil if there is no error.
func NewFileStorage(caURL *url.URL) (Storage, error) { func NewFileStorage(caURL *url.URL) (Storage, error) {
// storageBasePath is the root path in which all TLS/ACME assets are
// stored. Do not change this value during the lifetime of the program.
storageBasePath := filepath.Join(caddy.AssetsPath(), "acme")
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)} storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage} storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
return storage, nil return storage, nil
...@@ -58,25 +58,25 @@ func (s *FileStorage) sites() string { ...@@ -58,25 +58,25 @@ func (s *FileStorage) sites() string {
// site returns the path to the folder containing assets for domain. // site returns the path to the folder containing assets for domain.
func (s *FileStorage) site(domain string) string { func (s *FileStorage) site(domain string) string {
domain = strings.ToLower(domain) domain = fileSafe(domain)
return filepath.Join(s.sites(), domain) return filepath.Join(s.sites(), domain)
} }
// siteCertFile returns the path to the certificate file for domain. // siteCertFile returns the path to the certificate file for domain.
func (s *FileStorage) siteCertFile(domain string) string { func (s *FileStorage) siteCertFile(domain string) string {
domain = strings.ToLower(domain) domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".crt") return filepath.Join(s.site(domain), domain+".crt")
} }
// siteKeyFile returns the path to domain's private key file. // siteKeyFile returns the path to domain's private key file.
func (s *FileStorage) siteKeyFile(domain string) string { func (s *FileStorage) siteKeyFile(domain string) string {
domain = strings.ToLower(domain) domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".key") return filepath.Join(s.site(domain), domain+".key")
} }
// siteMetaFile returns the path to the domain's asset metadata file. // siteMetaFile returns the path to the domain's asset metadata file.
func (s *FileStorage) siteMetaFile(domain string) string { func (s *FileStorage) siteMetaFile(domain string) string {
domain = strings.ToLower(domain) domain = fileSafe(domain)
return filepath.Join(s.site(domain), domain+".json") return filepath.Join(s.site(domain), domain+".json")
} }
...@@ -90,7 +90,7 @@ func (s *FileStorage) user(email string) string { ...@@ -90,7 +90,7 @@ func (s *FileStorage) user(email string) string {
if email == "" { if email == "" {
email = emptyEmail email = emptyEmail
} }
email = strings.ToLower(email) email = fileSafe(email)
return filepath.Join(s.users(), email) return filepath.Join(s.users(), email)
} }
...@@ -117,6 +117,7 @@ func (s *FileStorage) userRegFile(email string) string { ...@@ -117,6 +117,7 @@ func (s *FileStorage) userRegFile(email string) string {
if fileName == "" { if fileName == "" {
fileName = "registration" fileName = "registration"
} }
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".json") return filepath.Join(s.user(email), fileName+".json")
} }
...@@ -131,6 +132,7 @@ func (s *FileStorage) userKeyFile(email string) string { ...@@ -131,6 +132,7 @@ func (s *FileStorage) userKeyFile(email string) string {
if fileName == "" { if fileName == "" {
fileName = "private" fileName = "private"
} }
fileName = fileSafe(fileName)
return filepath.Join(s.user(email), fileName+".key") return filepath.Join(s.user(email), fileName+".key")
} }
...@@ -274,3 +276,29 @@ func (s *FileStorage) MostRecentUserEmail() string { ...@@ -274,3 +276,29 @@ func (s *FileStorage) MostRecentUserEmail() string {
} }
return "" return ""
} }
// fileSafe standardizes and sanitizes str for use in a file path.
func fileSafe(str string) string {
str = strings.ToLower(str)
str = strings.TrimSpace(str)
repl := strings.NewReplacer("..", "",
"/", "",
"\\", "",
// TODO: Consider also replacing "@" with "_at_" (but migrate existing accounts...)
"+", "_plus_",
"%", "",
"$", "",
"`", "",
"~", "",
":", "",
";", "",
"=", "",
"!", "",
"#", "",
"&", "",
"|", "",
"\"", "",
"'", "",
"*", "wildcard_")
return repl.Replace(str)
}
...@@ -14,7 +14,71 @@ ...@@ -14,7 +14,71 @@
package caddytls package caddytls
import (
"path/filepath"
"testing"
)
// *********************************** NOTE ******************************** // *********************************** NOTE ********************************
// Due to circular package dependencies with the storagetest sub package and // Due to circular package dependencies with the storagetest sub package and
// the fact that we want to use that harness to test file storage, the tests // the fact that we want to use that harness to test file storage, most of
// for file storage are done in the storagetest package. // the tests for file storage are done in the storagetest package.
func TestPathBuilders(t *testing.T) {
fs := FileStorage{Path: "test"}
for i, testcase := range []struct {
in, folder, certFile, keyFile, metaFile string
}{
{
in: "example.com",
folder: filepath.Join("test", "sites", "example.com"),
certFile: filepath.Join("test", "sites", "example.com", "example.com.crt"),
keyFile: filepath.Join("test", "sites", "example.com", "example.com.key"),
metaFile: filepath.Join("test", "sites", "example.com", "example.com.json"),
},
{
in: "*.example.com",
folder: filepath.Join("test", "sites", "wildcard_.example.com"),
certFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.crt"),
keyFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.key"),
metaFile: filepath.Join("test", "sites", "wildcard_.example.com", "wildcard_.example.com.json"),
},
{
// prevent directory traversal! very important, esp. with on-demand TLS
// see issue #2092
in: "a/../../../foo",
folder: filepath.Join("test", "sites", "afoo"),
certFile: filepath.Join("test", "sites", "afoo", "afoo.crt"),
keyFile: filepath.Join("test", "sites", "afoo", "afoo.key"),
metaFile: filepath.Join("test", "sites", "afoo", "afoo.json"),
},
{
in: "b\\..\\..\\..\\foo",
folder: filepath.Join("test", "sites", "bfoo"),
certFile: filepath.Join("test", "sites", "bfoo", "bfoo.crt"),
keyFile: filepath.Join("test", "sites", "bfoo", "bfoo.key"),
metaFile: filepath.Join("test", "sites", "bfoo", "bfoo.json"),
},
{
in: "c/foo",
folder: filepath.Join("test", "sites", "cfoo"),
certFile: filepath.Join("test", "sites", "cfoo", "cfoo.crt"),
keyFile: filepath.Join("test", "sites", "cfoo", "cfoo.key"),
metaFile: filepath.Join("test", "sites", "cfoo", "cfoo.json"),
},
} {
if actual := fs.site(testcase.in); actual != testcase.folder {
t.Errorf("Test %d: site folder: Expected '%s' but got '%s'", i, testcase.folder, actual)
}
if actual := fs.siteCertFile(testcase.in); actual != testcase.certFile {
t.Errorf("Test %d: site cert file: Expected '%s' but got '%s'", i, testcase.certFile, actual)
}
if actual := fs.siteKeyFile(testcase.in); actual != testcase.keyFile {
t.Errorf("Test %d: site key file: Expected '%s' but got '%s'", i, testcase.keyFile, actual)
}
if actual := fs.siteMetaFile(testcase.in); actual != testcase.metaFile {
t.Errorf("Test %d: site meta file: Expected '%s' but got '%s'", i, testcase.metaFile, actual)
}
}
}
...@@ -91,7 +91,20 @@ func (s *fileStorageLock) Unlock(name string) error { ...@@ -91,7 +91,20 @@ func (s *fileStorageLock) Unlock(name string) error {
if !ok { if !ok {
return fmt.Errorf("FileStorage: no lock to release for %s", name) return fmt.Errorf("FileStorage: no lock to release for %s", name)
} }
// remove lock file
os.Remove(fw.filename) os.Remove(fw.filename)
// if parent folder is now empty, remove it too to keep it tidy
lockParentFolder := s.storage.site(name)
dir, err := os.Open(lockParentFolder)
if err == nil {
items, _ := dir.Readdirnames(3) // OK to ignore error here
if len(items) == 0 {
os.Remove(lockParentFolder)
}
dir.Close()
}
fw.wg.Done() fw.wg.Done()
delete(fileStorageNameLocks, s.caURL+name) delete(fileStorageNameLocks, s.caURL+name)
return nil return nil
......
...@@ -61,10 +61,9 @@ func (cg configGroup) getConfig(name string) *Config { ...@@ -61,10 +61,9 @@ func (cg configGroup) getConfig(name string) *Config {
} }
} }
// try a config that serves all names (this // try a config that serves all names (the above
// is basically the same as a config defined // loop doesn't try empty string; for hosts defined
// for "*" -- I think -- but the above loop // with only a port, for instance, like ":443")
// doesn't try an empty string)
if config, ok := cg[""]; ok { if config, ok := cg[""]; ok {
return config return config
} }
...@@ -190,18 +189,20 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau ...@@ -190,18 +189,20 @@ func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defau
return return
} }
// if nothing matches and SNI was not provided, use a random // if nothing matches, use a random certificate
// certificate; at least there's a chance this older client // TODO: This is not my favorite behavior; I would rather serve
// can connect, and in the future we won't need this provision // no certificate if SNI is provided and cause a TLS alert, than
// (if SNI is present, it's probably best to just raise a TLS // serve the wrong certificate (but sometimes the 'wrong' cert
// alert by not serving a certificate) // is what is wanted, but in those cases I would prefer that the
if name == "" { // site owner explicitly configure a "default" certificate).
// (See issue 2035; any change to this behavior must account for
// hosts defined like ":443" or "0.0.0.0:443" where the hostname
// is empty or a catch-all IP or something.)
for _, certKey := range cfg.Certificates { for _, certKey := range cfg.Certificates {
defaulted = true
cert = cfg.certCache.cache[certKey] cert = cfg.certCache.cache[certKey]
defaulted = true
return return
} }
}
return return
} }
......
...@@ -27,7 +27,7 @@ func TestGetCertificate(t *testing.T) { ...@@ -27,7 +27,7 @@ func TestGetCertificate(t *testing.T) {
hello := &tls.ClientHelloInfo{ServerName: "example.com"} hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{} helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // TODO (see below)
// When cache is empty // When cache is empty
if cert, err := cfg.GetCertificate(hello); err == nil { if cert, err := cfg.GetCertificate(hello); err == nil {
...@@ -69,8 +69,9 @@ func TestGetCertificate(t *testing.T) { ...@@ -69,8 +69,9 @@ func TestGetCertificate(t *testing.T) {
t.Errorf("Expected random cert with no matches, got: %v", cert) t.Errorf("Expected random cert with no matches, got: %v", cert)
} }
// TODO: Re-implement this behavior (it was reverted in #2037)
// When no certificate matches, raise an alert // When no certificate matches, raise an alert
if _, err := cfg.GetCertificate(helloNoMatch); err == nil { // if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err) // t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
} // }
} }
...@@ -16,12 +16,16 @@ package caddytls ...@@ -16,12 +16,16 @@ package caddytls
import ( import (
"crypto/tls" "crypto/tls"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os"
"strings" "strings"
"github.com/xenolf/lego/acmev2"
) )
const challengeBasePath = "/.well-known/acme-challenge" const challengeBasePath = "/.well-known/acme-challenge"
...@@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str ...@@ -38,6 +42,13 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
if DisableHTTPChallenge { if DisableHTTPChallenge {
return false return false
} }
// see if another instance started the HTTP challenge for this name
if tryDistributedChallengeSolver(w, r) {
return true
}
// otherwise, if we aren't getting the name, then ignore this challenge
if !namesObtaining.Has(r.Host) { if !namesObtaining.Has(r.Host) {
return false return false
} }
...@@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str ...@@ -70,3 +81,40 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost str
return true return true
} }
// tryDistributedChallengeSolver checks to see if this challenge
// request was initiated by another instance that shares file
// storage, and attempts to complete the challenge for it. It
// returns true if the challenge was handled; false otherwise.
func tryDistributedChallengeSolver(w http.ResponseWriter, r *http.Request) bool {
filePath := distributedHTTPSolver{}.challengeTokensPath(r.Host)
f, err := os.Open(filePath)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("[ERROR][%s] Opening distributed challenge token file: %v", r.Host, err)
}
return false
}
defer f.Close()
var chalInfo challengeInfo
err = json.NewDecoder(f).Decode(&chalInfo)
if err != nil {
log.Printf("[ERROR][%s] Decoding challenge token file %s (corrupted?): %v", r.Host, filePath, err)
return false
}
// this part borrowed from xenolf/lego's built-in HTTP-01 challenge solver (March 2018)
challengeReqPath := acme.HTTP01ChallengePath(chalInfo.Token)
if r.URL.Path == challengeReqPath &&
strings.HasPrefix(r.Host, chalInfo.Domain) &&
r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(chalInfo.KeyAuth))
r.Close = true
log.Printf("[INFO][%s] Served key authentication", chalInfo.Domain)
return true
}
return false
}
...@@ -334,6 +334,7 @@ func DeleteOldStapleFiles() { ...@@ -334,6 +334,7 @@ func DeleteOldStapleFiles() {
if err != nil { if err != nil {
log.Printf("[ERROR] Purging corrupt staple file %s: %v", stapleFile, err) log.Printf("[ERROR] Purging corrupt staple file %s: %v", stapleFile, err)
} }
continue
} }
if time.Now().After(resp.NextUpdate) { if time.Now().After(resp.NextUpdate) {
// response has expired; delete it // response has expired; delete it
......
...@@ -107,19 +107,19 @@ func setupTLS(c *caddy.Controller) error { ...@@ -107,19 +107,19 @@ func setupTLS(c *caddy.Controller) error {
case "protocols": case "protocols":
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) == 1 { if len(args) == 1 {
value, ok := supportedProtocols[strings.ToLower(args[0])] value, ok := SupportedProtocols[strings.ToLower(args[0])]
if !ok { if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0]) return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
} }
config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value
} else { } else {
value, ok := supportedProtocols[strings.ToLower(args[0])] value, ok := SupportedProtocols[strings.ToLower(args[0])]
if !ok { if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0]) return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
} }
config.ProtocolMinVersion = value config.ProtocolMinVersion = value
value, ok = supportedProtocols[strings.ToLower(args[1])] value, ok = SupportedProtocols[strings.ToLower(args[1])]
if !ok { if !ok {
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1]) return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1])
} }
...@@ -130,7 +130,7 @@ func setupTLS(c *caddy.Controller) error { ...@@ -130,7 +130,7 @@ func setupTLS(c *caddy.Controller) error {
} }
case "ciphers": case "ciphers":
for c.NextArg() { for c.NextArg() {
value, ok := supportedCiphersMap[strings.ToUpper(c.Val())] value, ok := SupportedCiphersMap[strings.ToUpper(c.Val())]
if !ok { if !ok {
return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val()) return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val())
} }
...@@ -210,8 +210,21 @@ func setupTLS(c *caddy.Controller) error { ...@@ -210,8 +210,21 @@ func setupTLS(c *caddy.Controller) error {
} }
case "must_staple": case "must_staple":
config.MustStaple = true config.MustStaple = true
case "wildcard":
if !HostQualifies(config.Hostname) {
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
}
if strings.Contains(config.Hostname, "*") {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: already has a wildcard label", config.Hostname)
}
parts := strings.Split(config.Hostname, ".")
if len(parts) < 3 {
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: too few labels", config.Hostname)
}
parts[0] = "*"
config.Hostname = strings.Join(parts, ".")
default: default:
return c.Errf("Unknown keyword '%s'", c.Val()) return c.Errf("Unknown subdirective '%s'", c.Val())
} }
} }
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
...@@ -67,8 +67,8 @@ func TestSetupParseBasic(t *testing.T) { ...@@ -67,8 +67,8 @@ func TestSetupParseBasic(t *testing.T) {
} }
// Security defaults // Security defaults
if cfg.ProtocolMinVersion != tls.VersionTLS11 { if cfg.ProtocolMinVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.1 (0x0302)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion) t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
} }
if cfg.ProtocolMaxVersion != tls.VersionTLS12 { if cfg.ProtocolMaxVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion) t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion)
......
...@@ -58,7 +58,8 @@ type Locker interface { ...@@ -58,7 +58,8 @@ type Locker interface {
// successfully obtained the lock (no Waiter value was returned) // successfully obtained the lock (no Waiter value was returned)
// should call this method, and it should be called only after // should call this method, and it should be called only after
// the obtain/renew and store are finished, even if there was // the obtain/renew and store are finished, even if there was
// an error (or a timeout). // an error (or a timeout). Unlock should also clean up any
// unused resources allocated during TryLock.
Unlock(name string) error Unlock(name string) error
} }
......
...@@ -30,26 +30,35 @@ package caddytls ...@@ -30,26 +30,35 @@ package caddytls
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil"
"log"
"net" "net"
"os"
"path/filepath"
"strings" "strings"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// HostQualifies returns true if the hostname alone // HostQualifies returns true if the hostname alone
// appears eligible for automatic HTTPS. For example, // appears eligible for automatic HTTPS. For example:
// localhost, empty hostname, and IP addresses are // localhost, empty hostname, and IP addresses are
// not eligible because we cannot obtain certificates // not eligible because we cannot obtain certificates
// for those names. // for those names. Wildcard names are allowed, as long
// as they conform to CABF requirements (only one wildcard
// label, and it must be the left-most label).
func HostQualifies(hostname string) bool { func HostQualifies(hostname string) bool {
return hostname != "localhost" && // localhost is ineligible return hostname != "localhost" && // localhost is ineligible
// hostname must not be empty // hostname must not be empty
strings.TrimSpace(hostname) != "" && strings.TrimSpace(hostname) != "" &&
// must not contain wildcard (*) characters (until CA supports it) // only one wildcard label allowed, and it must be left-most
!strings.Contains(hostname, "*") && (!strings.Contains(hostname, "*") ||
(strings.Count(hostname, "*") == 1 &&
strings.HasPrefix(hostname, "*."))) &&
// must not start or end with a dot // must not start or end with a dot
!strings.HasPrefix(hostname, ".") && !strings.HasPrefix(hostname, ".") &&
...@@ -88,39 +97,125 @@ func Revoke(host string) error { ...@@ -88,39 +97,125 @@ func Revoke(host string) error {
return client.Revoke(host) return client.Revoke(host)
} }
// tlsSNISolver is a type that can solve TLS-SNI challenges using // TODO: tls-sni challenge was removed in January 2018, but a variant of it might return
// an existing listener and our custom, in-memory certificate cache. // // tlsSNISolver is a type that can solve TLS-SNI challenges using
type tlsSNISolver struct { // // an existing listener and our custom, in-memory certificate cache.
certCache *certificateCache // type tlsSNISolver struct {
// certCache *certificateCache
// }
// // Present adds the challenge certificate to the cache.
// func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
// cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// certHash := hashCertificateChain(cert.Certificate)
// s.certCache.Lock()
// s.certCache.cache[acmeDomain] = Certificate{
// Certificate: cert,
// Names: []string{acmeDomain},
// Hash: certHash, // perhaps not necesssary
// }
// s.certCache.Unlock()
// return nil
// }
// // CleanUp removes the challenge certificate from the cache.
// func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
// _, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
// if err != nil {
// return err
// }
// s.certCache.Lock()
// delete(s.certCache.cache, acmeDomain)
// s.certCache.Unlock()
// return nil
// }
// distributedHTTPSolver allows the HTTP-01 challenge to be solved by
// an instance other than the one which initiated it. This is useful
// behind load balancers or in other cluster/fleet configurations.
// The only requirement is that this (the initiating) instance share
// the $CADDYPATH/acme folder with the instance that will complete
// the challenge. Mounting the folder locally should be sufficient.
//
// Obviously, the instance which completes the challenge must be
// serving on the HTTPChallengePort to receive and handle the request.
// The HTTP server which receives it must check if a file exists, e.g.:
// $CADDYPATH/acme/challenge_tokens/example.com.json, and if so,
// decode it and use it to serve up the correct response. Caddy's HTTP
// server does this by default.
//
// So as long as the folder is shared, this will just work. There are
// no other requirements. The instances may be on other machines or
// even other networks, as long as they share the folder as part of
// the local file system.
//
// This solver works by persisting the token and keyauth information
// to disk in the shared folder when the authorization is presented,
// and then deletes it when it is cleaned up.
type distributedHTTPSolver struct {
// The distributed HTTPS solver only works if an instance (either
// this one or another one) is already listening and serving on the
// HTTPChallengePort. If not -- for example: if this is the only
// instance, and it is just starting up and hasn't started serving
// yet -- then we still need a listener open with an HTTP server
// to handle the challenge request. Set this field to have the
// standard HTTPProviderServer open its listener for the duration
// of the challenge. Make sure to configure its listen address
// correctly.
httpProviderServer *acme.HTTPProviderServer
}
type challengeInfo struct {
Domain, Token, KeyAuth string
} }
// Present adds the challenge certificate to the cache. // Present adds the challenge certificate to the cache.
func (s tlsSNISolver) Present(domain, token, keyAuth string) error { func (dhs distributedHTTPSolver) Present(domain, token, keyAuth string) error {
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.Present(domain, token, keyAuth)
if err != nil {
return fmt.Errorf("presenting with standard HTTP provider server: %v", err)
}
}
err := os.MkdirAll(dhs.challengeTokensBasePath(), 0755)
if err != nil { if err != nil {
return err return err
} }
certHash := hashCertificateChain(cert.Certificate)
s.certCache.Lock() infoBytes, err := json.Marshal(challengeInfo{
s.certCache.cache[acmeDomain] = Certificate{ Domain: domain,
Certificate: cert, Token: token,
Names: []string{acmeDomain}, KeyAuth: keyAuth,
Hash: certHash, // perhaps not necesssary })
if err != nil {
return err
} }
s.certCache.Unlock()
return nil return ioutil.WriteFile(dhs.challengeTokensPath(domain), infoBytes, 0644)
} }
// CleanUp removes the challenge certificate from the cache. // CleanUp removes the challenge certificate from the cache.
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error { func (dhs distributedHTTPSolver) CleanUp(domain, token, keyAuth string) error {
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth) if dhs.httpProviderServer != nil {
err := dhs.httpProviderServer.CleanUp(domain, token, keyAuth)
if err != nil { if err != nil {
return err log.Printf("[ERROR] Cleaning up standard HTTP provider server: %v", err)
} }
s.certCache.Lock() }
delete(s.certCache.cache, acmeDomain) return os.Remove(dhs.challengeTokensPath(domain))
s.certCache.Unlock() }
return nil
func (dhs distributedHTTPSolver) challengeTokensPath(domain string) string {
domainFile := strings.Replace(strings.ToLower(domain), "*", "wildcard_", -1)
return filepath.Join(dhs.challengeTokensBasePath(), domainFile+".json")
}
func (dhs distributedHTTPSolver) challengeTokensBasePath() string {
return filepath.Join(caddy.AssetsPath(), "acme", "challenge_tokens")
} }
// ConfigHolder is any type that has a Config; it presumably is // ConfigHolder is any type that has a Config; it presumably is
......
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestHostQualifies(t *testing.T) { func TestHostQualifies(t *testing.T) {
...@@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) { ...@@ -37,7 +37,10 @@ func TestHostQualifies(t *testing.T) {
{"0.0.0.0", false}, {"0.0.0.0", false},
{"", false}, {"", false},
{" ", false}, {" ", false},
{"*.example.com", false}, {"*.example.com", true},
{"*.*.example.com", false},
{"sub.*.example.com", false},
{"*sub.example.com", false},
{".com", false}, {".com", false},
{"example.com.", false}, {"example.com.", false},
{"localhost", false}, {"localhost", false},
...@@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) { ...@@ -77,7 +80,10 @@ func TestQualifiesForManagedTLS(t *testing.T) {
{holder{host: "localhost", cfg: new(Config)}, false}, {holder{host: "localhost", cfg: new(Config)}, false},
{holder{host: "123.44.3.21", cfg: new(Config)}, false}, {holder{host: "123.44.3.21", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: new(Config)}, true}, {holder{host: "example.com", cfg: new(Config)}, true},
{holder{host: "*.example.com", cfg: new(Config)}, false}, {holder{host: "*.example.com", cfg: new(Config)}, true},
{holder{host: "*.*.example.com", cfg: new(Config)}, false},
{holder{host: "*sub.example.com", cfg: new(Config)}, false},
{holder{host: "sub.*.example.com", cfg: new(Config)}, false},
{holder{host: "example.com", cfg: &Config{Manual: true}}, false}, {holder{host: "example.com", cfg: &Config{Manual: true}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false}, {holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false},
{holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true}, {holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true},
......
...@@ -27,7 +27,7 @@ import ( ...@@ -27,7 +27,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
// User represents a Let's Encrypt user account. // User represents a Let's Encrypt user account.
...@@ -67,43 +67,82 @@ func newUser(email string) (User, error) { ...@@ -67,43 +67,82 @@ func newUser(email string) (User, error) {
return user, nil return user, nil
} }
// getEmail does everything it can to obtain an email // getEmail does everything it can to obtain an email address
// address from the user within the scope of storage // from the user within the scope of memory and storage to use
// to use for ACME TLS. If it cannot get an email // for ACME TLS. If it cannot get an email address, it returns
// address, it returns empty string. (It will warn the // empty string. (If user is present, it will warn the user of
// user of the consequences of an empty email.) This // the consequences of an empty email.) This function MAY prompt
// function MAY prompt the user for input. If userPresent // the user for input. If userPresent is false, the operator
// is false, the operator will NOT be prompted and an // will NOT be prompted and an empty email may be returned.
// empty email may be returned. // If the user is prompted, a new User will be created and
func getEmail(storage Storage, userPresent bool) string { // stored in storage according to the email address they
// provided (which might be blank).
func getEmail(cfg *Config, userPresent bool) (string, error) {
storage, err := cfg.StorageFor(cfg.CAUrl)
if err != nil {
return "", err
}
// First try memory (command line flag or typed by user previously) // First try memory (command line flag or typed by user previously)
leEmail := DefaultEmail leEmail := DefaultEmail
// Then try to get most recent user email from storage
if leEmail == "" { if leEmail == "" {
// Then try to get most recent user email
leEmail = storage.MostRecentUserEmail() leEmail = storage.MostRecentUserEmail()
// Save for next time DefaultEmail = leEmail // save for next time
DefaultEmail = leEmail
} }
// Looks like there is no email address readily available,
// so we will have to ask the user if we can.
if leEmail == "" && userPresent { if leEmail == "" && userPresent {
// Alas, we must bother the user and ask for an email address; // evidently, no User data was present in storage;
// if they proceed they also agree to the SA. // thus we must make a new User so that we can get
// the Terms of Service URL via our ACME client, phew!
user, err := newUser("")
if err != nil {
return "", err
}
// get the agreement URL
agreementURL := agreementTestURL
if agreementURL == "" {
// we call acme.NewClient directly because newACMEClient
// would require that we already know the user's email
caURL := DefaultCAUrl
if cfg.CAUrl != "" {
caURL = cfg.CAUrl
}
tempClient, err := acme.NewClient(caURL, user, "")
if err != nil {
return "", fmt.Errorf("making ACME client to get ToS URL: %v", err)
}
agreementURL = tempClient.GetToSURL()
}
// prompt the user for an email address and terms agreement
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") promptUserAgreement(agreementURL)
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") fmt.Println("Please enter your email address to signify agreement and to be notified")
fmt.Println(" " + saURL) // TODO: Show current SA link fmt.Println("in case of issues. You can leave it blank, but we don't recommend it.")
fmt.Println("Please enter your email address so you can recover your account if needed.") fmt.Print(" Email address: ")
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
fmt.Print("Email address: ")
var err error
leEmail, err = reader.ReadString('\n') leEmail, err = reader.ReadString('\n')
if err != nil { if err != nil && err != io.EOF {
return "" return "", fmt.Errorf("reading email address: %v", err)
} }
leEmail = strings.TrimSpace(leEmail) leEmail = strings.TrimSpace(leEmail)
DefaultEmail = leEmail DefaultEmail = leEmail
Agreed = true Agreed = true
// save the new user to preserve this for next time
user.Email = leEmail
err = saveUser(storage, user)
if err != nil {
return "", err
}
} }
return strings.ToLower(leEmail)
// lower-casing the email is important for consistency
return strings.ToLower(leEmail), nil
} }
// getUser loads the user with the given email from disk // getUser loads the user with the given email from disk
...@@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error { ...@@ -154,18 +193,21 @@ func saveUser(storage Storage, user User) error {
return err return err
} }
// promptUserAgreement prompts the user to agree to the agreement // promptUserAgreement simply outputs the standard user
// at agreementURL via stdin. If the agreement has changed, then pass // agreement prompt with the given agreement URL.
// true as the second argument. If this is the user's first time // It outputs a newline after the message.
// agreeing, pass false. It returns whether the user agreed or not. func promptUserAgreement(agreementURL string) {
func promptUserAgreement(agreementURL string, changed bool) bool { const userAgreementPrompt = `Your sites will be served over HTTPS automatically using Let's Encrypt.
if changed { By continuing, you agree to the Let's Encrypt Subscriber Agreement at:`
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL) fmt.Printf("\n\n%s\n %s\n", userAgreementPrompt, agreementURL)
fmt.Print("Do you agree to the new terms? (y/n): ") }
} else {
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL) // askUserAgreement prompts the user to agree to the agreement
// at the given agreement URL via stdin. It returns whether the
// user agreed or not.
func askUserAgreement(agreementURL string) bool {
promptUserAgreement(agreementURL)
fmt.Print("Do you agree to the terms? (y/n): ") fmt.Print("Do you agree to the terms? (y/n): ")
}
reader := bufio.NewReader(stdin) reader := bufio.NewReader(stdin)
answer, err := reader.ReadString('\n') answer, err := reader.ReadString('\n')
...@@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool { ...@@ -177,14 +219,15 @@ func promptUserAgreement(agreementURL string, changed bool) bool {
return answer == "y" || answer == "yes" return answer == "y" || answer == "yes"
} }
// agreementTestURL is set during tests to skip requiring
// setting up an entire ACME CA endpoint.
var agreementTestURL string
// stdin is used to read the user's input if prompted; // stdin is used to read the user's input if prompted;
// this is changed by tests during tests. // this is changed by tests during tests.
var stdin = io.ReadWriter(os.Stdin) var stdin = io.ReadWriter(os.Stdin)
// The name of the folder for accounts where the email // The name of the folder for accounts where the email
// address was not provided; default 'username' if you will. // address was not provided; default 'username' if you will,
// but only for local/storage use, not with the CA.
const emptyEmail = "default" const emptyEmail = "default"
// TODO: After Boulder implements the 'meta' field of the directory,
// we can get this link dynamically.
const saURL = "https://acme-v01.api.letsencrypt.org/terms"
...@@ -20,13 +20,14 @@ import ( ...@@ -20,13 +20,14 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"io" "io"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
"os" "os"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acmev2"
) )
func TestUser(t *testing.T) { func TestUser(t *testing.T) {
...@@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) { ...@@ -135,7 +136,13 @@ func TestGetUserAlreadyExists(t *testing.T) {
} }
func TestGetEmail(t *testing.T) { func TestGetEmail(t *testing.T) {
storageBasePath = testStorage.Path // to contain calls that create a new Storage... // ensure storage (via StorageFor) uses the local testdata folder that we delete later
origCaddypath := os.Getenv("CADDYPATH")
os.Setenv("CADDYPATH", "./testdata")
defer os.Setenv("CADDYPATH", origCaddypath)
agreementTestURL = "(none - testing)"
defer func() { agreementTestURL = "" }()
// let's not clutter up the output // let's not clutter up the output
origStdout := os.Stdout origStdout := os.Stdout
...@@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) { ...@@ -146,7 +153,10 @@ func TestGetEmail(t *testing.T) {
DefaultEmail = "test2@foo.com" DefaultEmail = "test2@foo.com"
// Test1: Use default email from flag (or user previously typing it) // Test1: Use default email from flag (or user previously typing it)
actual := getEmail(testStorage, true) actual, err := getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (1) error: %v", err)
}
if actual != DefaultEmail { if actual != DefaultEmail {
t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual) t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual)
} }
...@@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) { ...@@ -154,16 +164,19 @@ func TestGetEmail(t *testing.T) {
// Test2: Get input from user // Test2: Get input from user
DefaultEmail = "" DefaultEmail = ""
stdin = new(bytes.Buffer) stdin = new(bytes.Buffer)
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n")) _, err = io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
if err != nil { if err != nil {
t.Fatalf("Could not simulate user input, error: %v", err) t.Fatalf("Could not simulate user input, error: %v", err)
} }
actual = getEmail(testStorage, true) actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (2) error: %v", err)
}
if actual != "test3@foo.com" { if actual != "test3@foo.com" {
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
} }
// Test3: Get most recent email from before // Test3: Get most recent email from before (in storage)
DefaultEmail = "" DefaultEmail = ""
for i, eml := range []string{ for i, eml := range []string{
"TEST4-3@foo.com", // test case insensitivity "TEST4-3@foo.com", // test case insensitivity
...@@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) { ...@@ -189,14 +202,20 @@ func TestGetEmail(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
} }
} }
actual = getEmail(testStorage, true) actual, err = getEmail(testConfig, true)
if err != nil {
t.Fatalf("getEmail (3) error: %v", err)
}
if actual != "test4-3@foo.com" { if actual != "test4-3@foo.com" {
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
} }
} }
var testStorage = &FileStorage{Path: "./testdata"} var (
testStorageBase = "./testdata" // ephemeral folder that gets deleted after tests finish
testCAHost = "localhost"
testConfig = &Config{CAUrl: "http://" + testCAHost + "/directory", StorageProvider: "file"}
testStorage = &FileStorage{Path: filepath.Join(testStorageBase, "acme", testCAHost)}
)
func (s *FileStorage) clean() error { func (s *FileStorage) clean() error { return os.RemoveAll(testStorageBase) }
return os.RemoveAll(s.Path)
}
CHANGES CHANGES
0.10.14 (April 19, 2018)
- tls: Fix error handling bug when obtaining certificates
0.10.13 (April 18, 2018)
- New third-party plugin: supervisor
- Updated QUIC
- proxy: Fix transparent pass-thru of X-Forwarded-For
- proxy: Configurable timeout to upstream
- rewrite: Now supports regular expressions on single-line
- tls: StrictHostMatching mode to prevent client auth bypass
- tls: Disable client auth when using QUIC
- tls: Require same client auth cert pools per hostname
- tls: Prevent On-Demand TLS directory traversal
- tls: Fix empty files when using ACME fails to obtain cert
- Fixed test broken by 1.1.1.1 resolving
- Improved Caddyfile parser robustness by fuzzing
0.10.12 (March 27, 2018)
- Switch to Let's Encrypt ACMEv2 production endpoint
- Support for automated wildcard certificates
- Support distributed solving of HTTP-01 challenge
- New {labelN}, {tls_cipher}, and {tls_version} placeholders
- Curly braces can now be escaped when not used as placeholders
- New third-party plugin: geoip
- Updated QUIC
- fastcgi: Add SSL_CIPHER and SSL_PROTOCOL environment variables
- log: New 'except' subdirective to exempt paths from logging
- startup/shutdown: Removed in favor of 'on'
- tls: Default minimum version is TLS 1.2
- tls: Revert to fallback cert if no cert matches SNI
- tls: New 'wildcard' subdirective to force automated wildcard cert
- Several significant bug fixes and improvements!
0.10.11 (February 20, 2018)
- Built with Go 1.10
- Reusable snippets for the Caddyfile
- Updated QUIC
- Auto-HTTPS certificates may be shared by multiple instances
- Expand globbed values in -conf flag
- Swap behavior of SIGTERM and SIGQUIT; ignore SIGHUP
- 9 new DNS provider plugins for the ACME DNS challenge
- New placeholder for {<Response-Header} values
- basicauth: Username put in {user} placeholder
- fastcgi: GET requests can now send a body
- proxy: Service discovery with DNS SRV load balancing
- request_id: Allow reusing request ID from header field
- tls: Improved efficiency of many certificates and reloads
- tls: Raise error if conflicting TLS configurations collide
- tls: Raise TLS alert if SNI used and no cert matched
- tls: Reject OCSP responses that expire after the certificate
- tls: Clients can use SNI to request a specific certificate
- tls: Add option for backend to approve on-demand certificate
- tls: Synchronize maintenance of shared, managed certificates
- Numerous fabulous bug fixes
0.10.10 (October 9, 2017) 0.10.10 (October 9, 2017)
- Built with Go 1.9.1 - Built with Go 1.9.1
- Removed Caddy-Sponsors header - Removed Caddy-Sponsors header
......
CADDY 0.10.10 CADDY 0.10.14
Website Website
https://caddyserver.com https://caddyserver.com
...@@ -32,9 +32,9 @@ the project wiki: https://github.com/mholt/caddy/wiki ...@@ -32,9 +32,9 @@ the project wiki: https://github.com/mholt/caddy/wiki
And thanks - you're awesome! And thanks - you're awesome!
If you think Caddy is awesome too, consider sponsoring it: If you think Caddy is awesome too, consider sponsoring it:
https://caddyserver.com/pricing - and help keep Caddy free https://caddyserver.com/sponsor - and help keep Caddy free
for personal use. for personal use.
--- ---
(c) 2015-2017 Light Code Labs, LLC (c) 2015-2018 Light Code Labs, LLC
...@@ -39,7 +39,7 @@ var ( ...@@ -39,7 +39,7 @@ var (
// eventHooks is a map of hook name to Hook. All hooks plugins // eventHooks is a map of hook name to Hook. All hooks plugins
// must have a name. // must have a name.
eventHooks = sync.Map{} eventHooks = &sync.Map{}
// parsingCallbacks maps server type to map of directive // parsingCallbacks maps server type to map of directive
// to list of callback functions. These aren't really // to list of callback functions. These aren't really
...@@ -296,6 +296,36 @@ func EmitEvent(event EventName, info interface{}) { ...@@ -296,6 +296,36 @@ func EmitEvent(event EventName, info interface{}) {
}) })
} }
// cloneEventHooks return a clone of the event hooks *sync.Map
func cloneEventHooks() *sync.Map {
c := &sync.Map{}
eventHooks.Range(func(k, v interface{}) bool {
c.Store(k, v)
return true
})
return c
}
// purgeEventHooks purges all event hooks from the map
func purgeEventHooks() {
eventHooks.Range(func(k, _ interface{}) bool {
eventHooks.Delete(k)
return true
})
}
// restoreEventHooks restores eventHooks with a provided *sync.Map
func restoreEventHooks(m *sync.Map) {
// Purge old event hooks
purgeEventHooks()
// Restore event hooks
m.Range(func(k, v interface{}) bool {
eventHooks.Store(k, v)
return true
})
}
// ParsingCallback is a function that is called after // ParsingCallback is a function that is called after
// a directive's setup functions have been executed // a directive's setup functions have been executed
// for all the server blocks. // for all the server blocks.
......
...@@ -83,9 +83,17 @@ func trapSignalsPosix() { ...@@ -83,9 +83,17 @@ func trapSignalsPosix() {
caddyfileToUse = newCaddyfile caddyfileToUse = newCaddyfile
} }
// Backup old event hooks
oldEventHooks := cloneEventHooks()
// Purge the old event hooks
purgeEventHooks()
// Kick off the restart; our work is done // Kick off the restart; our work is done
_, err = inst.Restart(caddyfileToUse) _, err = inst.Restart(caddyfileToUse)
if err != nil { if err != nil {
restoreEventHooks(oldEventHooks)
log.Printf("[ERROR] SIGUSR1: %v", err) log.Printf("[ERROR] SIGUSR1: %v", err)
} }
......
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package startupshutdown
import (
"fmt"
"strings"
"github.com/google/uuid"
"github.com/mholt/caddy"
"github.com/mholt/caddy/onevent/hook"
)
func init() {
caddy.RegisterPlugin("startup", caddy.Plugin{Action: Startup})
caddy.RegisterPlugin("shutdown", caddy.Plugin{Action: Shutdown})
}
// Startup (an alias for 'on startup') registers a startup callback to execute during server start.
func Startup(c *caddy.Controller) error {
config, err := onParse(c, caddy.InstanceStartupEvent)
if err != nil {
return c.ArgErr()
}
// Register Event Hooks.
c.OncePerServerBlock(func() error {
for _, cfg := range config {
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
}
return nil
})
fmt.Println("NOTICE: Startup directive will be removed in a later version. Please migrate to 'on startup'")
return nil
}
// Shutdown (an alias for 'on shutdown') registers a shutdown callback to execute during server start.
func Shutdown(c *caddy.Controller) error {
config, err := onParse(c, caddy.ShutdownEvent)
if err != nil {
return c.ArgErr()
}
// Register Event Hooks.
for _, cfg := range config {
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
}
fmt.Println("NOTICE: Shutdown directive will be removed in a later version. Please migrate to 'on shutdown'")
return nil
}
func onParse(c *caddy.Controller, event caddy.EventName) ([]*hook.Config, error) {
var config []*hook.Config
for c.Next() {
cfg := new(hook.Config)
args := c.RemainingArgs()
if len(args) == 0 {
return config, c.ArgErr()
}
// Configure Event.
cfg.Event = event
// Assign an unique ID.
cfg.ID = uuid.New().String()
// Extract command and arguments.
command, args, err := caddy.SplitCommandAndArgs(strings.Join(args, " "))
if err != nil {
return config, c.Err(err.Error())
}
cfg.Command = command
cfg.Args = args
config = append(config, cfg)
}
return config, nil
}
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package startupshutdown
import (
"testing"
"github.com/mholt/caddy"
)
func TestStartup(t *testing.T) {
tests := []struct {
name string
input string
shouldErr bool
}{
{name: "noInput", input: "startup", shouldErr: true},
{name: "startup", input: "startup cmd arg", shouldErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := caddy.NewTestController("", test.input)
err := Startup(c)
if err == nil && test.shouldErr {
t.Error("Test didn't error, but it should have")
} else if err != nil && !test.shouldErr {
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
}
})
}
}
func TestShutdown(t *testing.T) {
tests := []struct {
name string
input string
shouldErr bool
}{
{name: "noInput", input: "shutdown", shouldErr: true},
{name: "shutdown", input: "shutdown cmd arg", shouldErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
c := caddy.NewTestController("", test.input)
err := Shutdown(c)
if err == nil && test.shouldErr {
t.Error("Test didn't error, but it should have")
} else if err != nil && !test.shouldErr {
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
}
})
}
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build !amd64,!s390x // +build !amd64
package aes12 package aes12
......
...@@ -8,19 +8,20 @@ import ( ...@@ -8,19 +8,20 @@ import (
var bufferPool sync.Pool var bufferPool sync.Pool
func getPacketBuffer() []byte { func getPacketBuffer() *[]byte {
return bufferPool.Get().([]byte) return bufferPool.Get().(*[]byte)
} }
func putPacketBuffer(buf []byte) { func putPacketBuffer(buf *[]byte) {
if cap(buf) != int(protocol.MaxReceivePacketSize) { if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!") panic("putPacketBuffer called with packet of wrong size!")
} }
bufferPool.Put(buf[:0]) bufferPool.Put(buf)
} }
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxReceivePacketSize) b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
} }
} }
...@@ -38,6 +38,8 @@ type client struct { ...@@ -38,6 +38,8 @@ type client struct {
version protocol.VersionNumber version protocol.VersionNumber
session packetHandler session packetHandler
logger utils.Logger
} }
var ( var (
...@@ -85,6 +87,14 @@ func Dial( ...@@ -85,6 +87,14 @@ func Dial(
} }
} }
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
...@@ -94,9 +104,10 @@ func Dial( ...@@ -94,9 +104,10 @@ func Dial(
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}), versionNegotiationChan: make(chan struct{}),
logger: utils.DefaultLogger,
} }
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil { if err := c.dial(); err != nil {
return nil, err return nil, err
...@@ -132,6 +143,18 @@ func populateClientConfig(config *Config) *Config { ...@@ -132,6 +143,18 @@ func populateClientConfig(config *Config) *Config {
if maxReceiveConnectionFlowControlWindow == 0 { if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
} }
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{ return &Config{
Versions: versions, Versions: versions,
...@@ -140,6 +163,8 @@ func populateClientConfig(config *Config) *Config { ...@@ -140,6 +163,8 @@ func populateClientConfig(config *Config) *Config {
RequestConnectionIDOmission: config.RequestConnectionIDOmission, RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
} }
} }
...@@ -171,12 +196,11 @@ func (c *client) dialTLS() error { ...@@ -171,12 +196,11 @@ func (c *client) dialTLS() error {
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout, IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission, OmitConnectionID: c.config.RequestConnectionIDOmission,
// TODO(#523): make these values configurable MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
} }
csc := handshake.NewCryptoStreamConn(nil) csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil { if err != nil {
return err return err
...@@ -193,7 +217,7 @@ func (c *client) dialTLS() error { ...@@ -193,7 +217,7 @@ func (c *client) dialTLS() error {
if err != handshake.ErrCloseSessionForRetry { if err != handshake.ErrCloseSessionForRetry {
return err return err
} }
utils.Infof("Received a Retry packet. Recreating session.") c.logger.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err return err
} }
...@@ -216,7 +240,7 @@ func (c *client) establishSecureConnection() error { ...@@ -216,7 +240,7 @@ func (c *client) establishSecureConnection() error {
go func() { go func() {
runErr = c.session.run() // returns as soon as the session is closed runErr = c.session.run() // returns as soon as the session is closed
close(errorChan) close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID) c.logger.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close() c.conn.Close()
} }
...@@ -245,7 +269,7 @@ func (c *client) listen() { ...@@ -245,7 +269,7 @@ func (c *client) listen() {
for { for {
var n int var n int
var addr net.Addr var addr net.Addr
data := getPacketBuffer() data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize] data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes // The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
...@@ -270,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { ...@@ -270,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version) hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil { if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header // drop this packet if we can't parse the header
return return
} }
...@@ -293,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { ...@@ -293,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
// check if the remote address and the connection ID match // check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.") c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
return return
} }
pr, err := wire.ParsePublicReset(r) pr, err := wire.ParsePublicReset(r)
if err != nil { if err != nil {
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return return
} }
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return return
} }
...@@ -347,6 +371,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { ...@@ -347,6 +371,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
} }
} }
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion return qerr.InvalidVersion
...@@ -362,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { ...@@ -362,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
if err != nil { if err != nil {
return err return err
} }
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion) c.session.Close(errCloseSessionForNewVersion)
return nil return nil
} }
...@@ -379,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) { ...@@ -379,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) {
c.config, c.config,
c.initialVersion, c.initialVersion,
c.negotiatedVersions, c.negotiatedVersions,
c.logger,
) )
return err return err
} }
...@@ -398,6 +425,7 @@ func (c *client) createNewTLSSession( ...@@ -398,6 +425,7 @@ func (c *client) createNewTLSSession(
c.tls, c.tls,
paramsChan, paramsChan,
1, 1,
c.logger,
) )
return err return err
} }
...@@ -19,12 +19,14 @@ func main() { ...@@ -19,12 +19,14 @@ func main() {
flag.Parse() flag.Parse()
urls := flag.Args() urls := flag.Args()
logger := utils.DefaultLogger
if *verbose { if *verbose {
utils.SetLogLevel(utils.LogLevelDebug) logger.SetLogLevel(utils.LogLevelDebug)
} else { } else {
utils.SetLogLevel(utils.LogLevelInfo) logger.SetLogLevel(utils.LogLevelInfo)
} }
utils.SetLogTimeFormat("") logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions versions := protocol.SupportedVersions
if *tls { if *tls {
...@@ -42,21 +44,21 @@ func main() { ...@@ -42,21 +44,21 @@ func main() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(urls)) wg.Add(len(urls))
for _, addr := range urls { for _, addr := range urls {
utils.Infof("GET %s", addr) logger.Infof("GET %s", addr)
go func(addr string) { go func(addr string) {
rsp, err := hclient.Get(addr) rsp, err := hclient.Get(addr)
if err != nil { if err != nil {
panic(err) panic(err)
} }
utils.Infof("Got response for %s: %#v", addr, rsp) logger.Infof("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{} body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body) _, err = io.Copy(body, rsp.Body)
if err != nil { if err != nil {
panic(err) panic(err)
} }
utils.Infof("Request Body:") logger.Infof("Request Body:")
utils.Infof("%s", body.Bytes()) logger.Infof("%s", body.Bytes())
wg.Done() wg.Done()
}(addr) }(addr)
} }
......
...@@ -91,7 +91,7 @@ func init() { ...@@ -91,7 +91,7 @@ func init() {
} }
} }
if err != nil { if err != nil {
utils.Infof("Error receiving upload: %#v", err) utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
} }
} }
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data"> io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
...@@ -126,12 +126,14 @@ func main() { ...@@ -126,12 +126,14 @@ func main() {
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)") tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse() flag.Parse()
logger := utils.DefaultLogger
if *verbose { if *verbose {
utils.SetLogLevel(utils.LogLevelDebug) logger.SetLogLevel(utils.LogLevelDebug)
} else { } else {
utils.SetLogLevel(utils.LogLevelInfo) logger.SetLogLevel(utils.LogLevelInfo)
} }
utils.SetLogTimeFormat("") logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions versions := protocol.SupportedVersions
if *tls { if *tls {
......
...@@ -46,6 +46,8 @@ type client struct { ...@@ -46,6 +46,8 @@ type client struct {
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
} }
var _ http.RoundTripper = &client{} var _ http.RoundTripper = &client{}
...@@ -75,6 +77,7 @@ func newClient( ...@@ -75,6 +77,7 @@ func newClient(
opts: opts, opts: opts,
headerErrored: make(chan struct{}), headerErrored: make(chan struct{}),
dialer: dialer, dialer: dialer,
logger: utils.DefaultLogger,
} }
} }
...@@ -95,7 +98,7 @@ func (c *client) dial() error { ...@@ -95,7 +98,7 @@ func (c *client) dial() error {
if err != nil { if err != nil {
return err return err
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream() go c.handleHeaderStream()
return nil return nil
} }
...@@ -108,7 +111,9 @@ func (c *client) handleHeaderStream() { ...@@ -108,7 +111,9 @@ func (c *client) handleHeaderStream() {
for err == nil { for err == nil {
err = c.readResponse(h2framer, decoder) err = c.readResponse(h2framer, decoder)
} }
utils.Debugf("Error handling header stream: %s", err) if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request // stop all running request
close(c.headerErrored) close(c.headerErrored)
...@@ -202,6 +207,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -202,6 +207,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
bodySent = true bodySent = true
} }
ctx := req.Context()
for !(bodySent && receivedResponse) { for !(bodySent && receivedResponse) {
select { select {
case res = <-responseChan: case res = <-responseChan:
...@@ -214,8 +220,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -214,8 +220,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored: case <-c.headerErrored:
// an error occured on the header stream // an error occurred on the header stream
_ = c.CloseWithError(c.headerErr) _ = c.CloseWithError(c.headerErr)
return nil, c.headerErr return nil, c.headerErr
} }
......
...@@ -23,13 +23,16 @@ type requestWriter struct { ...@@ -23,13 +23,16 @@ type requestWriter struct {
henc *hpack.Encoder henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
} }
const defaultUserAgent = "quic-go" const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream) *requestWriter { func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{ rw := &requestWriter{
headerStream: headerStream, headerStream: headerStream,
logger: logger,
} }
rw.henc = hpack.NewEncoder(&rw.hbuf) rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw return rw
...@@ -76,9 +79,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra ...@@ -76,9 +79,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
if !validPseudoPath(path) { if !validPseudoPath(path) {
if req.URL.Opaque != "" { if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
} }
return nil, fmt.Errorf("invalid request :path %q", orig)
} }
} }
} }
...@@ -157,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra ...@@ -157,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
} }
func (w *requestWriter) writeHeader(name, value string) { func (w *requestWriter) writeHeader(name, value string) {
utils.Debugf("http2: Transport encoding header %q = %q", name, value) w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
} }
......
...@@ -3,7 +3,6 @@ package h2quic ...@@ -3,7 +3,6 @@ package h2quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/textproto" "net/textproto"
...@@ -16,7 +15,7 @@ import ( ...@@ -16,7 +15,7 @@ import (
// copied from net/http2/transport.go // copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function // from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
...@@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { ...@@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
return nil, errors.New("malformed non-numeric status pseudo header") return nil, errors.New("malformed non-numeric status pseudo header")
} }
if statusCode == 100 { // TODO: handle statusCode == 100
// TODO: handle this
// traceGot100Continue(cs.trace)
// if cs.on100 != nil {
// cs.on100() // forces any write delay timer to fire
// }
// cs.pastHeaders = false // do it all again
// return nil, nil
}
header := make(http.Header) header := make(http.Header)
res := &http.Response{ res := &http.Response{
...@@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { ...@@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if clens := res.Header["Content-Length"]; len(clens) == 1 { if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64 res.ContentLength = clen64
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
} }
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
} }
} }
return res return res
......
...@@ -24,15 +24,24 @@ type responseWriter struct { ...@@ -24,15 +24,24 @@ type responseWriter struct {
header http.Header header http.Header
status int // status code passed to WriteHeader status int // status code passed to WriteHeader
headerWritten bool headerWritten bool
logger utils.Logger
} }
func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter { func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
headerStream: headerStream, headerStream: headerStream,
headerStreamMutex: headerStreamMutex, headerStreamMutex: headerStreamMutex,
dataStream: dataStream, dataStream: dataStream,
dataStreamID: dataStreamID, dataStreamID: dataStreamID,
logger: logger,
} }
} }
...@@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) { ...@@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) {
} }
} }
utils.Infof("Responding with %d", status) w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock() w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock() defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil) h2framer := http2.NewFramer(w.headerStream, nil)
...@@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) { ...@@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) {
BlockFragment: headers.Bytes(), BlockFragment: headers.Bytes(),
}) })
if err != nil { if err != nil {
utils.Errorf("could not write h2 header: %s", err.Error()) w.logger.Errorf("could not write h2 header: %s", err.Error())
} }
} }
......
...@@ -53,6 +53,8 @@ type Server struct { ...@@ -53,6 +53,8 @@ type Server struct {
closed bool closed bool
supportedVersionsAsString string supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
...@@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { ...@@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
s.logger = utils.DefaultLogger
s.listenerMutex.Lock() s.listenerMutex.Lock()
if s.closed { if s.closed {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
...@@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) { ...@@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) {
// In this case, the session has already logged the error, so we don't // In this case, the session has already logged the error, so we don't
// need to log it again. // need to log it again.
if _, ok := err.(*qerr.QuicError); !ok { if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error()) s.logger.Errorf("error handling h2 request: %s", err.Error())
} }
session.Close(err) session.Close(err)
return return
...@@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
} }
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment()) headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil { if err != nil {
utils.Errorf("invalid http2 headers encoding: %s", err.Error()) s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err return err
} }
...@@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return err return err
} }
if utils.Debug() { if s.logger.Debug() {
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else { } else {
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
} }
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
...@@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
req.RemoteAddr = session.RemoteAddr().String() req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler handler := s.Handler
if handler == nil { if handler == nil {
...@@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
const size = 64 << 10 const size = 64 << 10
buf := make([]byte, size) buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)] buf = buf[:runtime.Stack(buf, false)]
utils.Errorf("http: panic serving: %v\n%s", p, buf) s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true panicked = true
} }
}() }()
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// Connection is a UDP connection // Connection is a UDP connection
...@@ -43,6 +44,8 @@ func (d Direction) String() string { ...@@ -43,6 +44,8 @@ func (d Direction) String() string {
} }
} }
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool { func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth { if d == DirectionBoth || dir == DirectionBoth {
return true return true
...@@ -92,6 +95,8 @@ type QuicProxy struct { ...@@ -92,6 +95,8 @@ type QuicProxy struct {
// Mapping from client addresses (as host:port) to connection // Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection clientDict map[string]*connection
logger utils.Logger
} }
// NewQuicProxy creates a new UDP proxy // NewQuicProxy creates a new UDP proxy
...@@ -129,14 +134,23 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu ...@@ -129,14 +134,23 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu
dropPacket: packetDropper, dropPacket: packetDropper,
delayPacket: packetDelayer, delayPacket: packetDelayer,
version: version, version: version,
logger: utils.DefaultLogger,
} }
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy() go p.runProxy()
return &p, nil return &p, nil
} }
// Close stops the UDP Proxy // Close stops the UDP Proxy
func (p *QuicProxy) Close() error { func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close() return p.conn.Close()
} }
...@@ -189,19 +203,27 @@ func (p *QuicProxy) runProxy() error { ...@@ -189,19 +203,27 @@ func (p *QuicProxy) runProxy() error {
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) { if p.dropPacket(DirectionIncoming, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue continue
} }
// Send the packet to the server // Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount) delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 { if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = conn.ServerConn.Write(raw) _, _ = conn.ServerConn.Write(raw)
}) })
} else { } else {
_, err := conn.ServerConn.Write(raw) if p.logger.Debug() {
if err != nil { p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err return err
} }
} }
...@@ -221,18 +243,26 @@ func (p *QuicProxy) runConnection(conn *connection) error { ...@@ -221,18 +243,26 @@ func (p *QuicProxy) runConnection(conn *connection) error {
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) { if p.dropPacket(DirectionOutgoing, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue continue
} }
delay := p.delayPacket(DirectionOutgoing, packetCount) delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 { if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
}) })
} else { } else {
_, err := p.conn.WriteToUDP(raw, conn.ClientAddr) if p.logger.Debug() {
if err != nil { p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err return err
} }
} }
......
...@@ -30,7 +30,7 @@ var _ = BeforeEach(func() { ...@@ -30,7 +30,7 @@ var _ = BeforeEach(func() {
logFile, err = os.Create(logFileName) logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile) log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug) utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
} }
}) })
......
...@@ -22,7 +22,9 @@ const ( ...@@ -22,7 +22,9 @@ const (
) )
var ( var (
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen) PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong) PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server server *h2quic.Server
...@@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) { ...@@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) {
}() }()
} }
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() { func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred()) Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed()) Eventually(stoppedServing).Should(BeClosed())
} }
// Port returns the UDP port of the QUIC server.
func Port() string { func Port() string {
return port return port
} }
...@@ -16,6 +16,9 @@ type StreamID = protocol.StreamID ...@@ -16,6 +16,9 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number. // A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber type VersionNumber = protocol.VersionNumber
// VersionGQUIC39 is gQUIC version 39.
const VersionGQUIC39 = protocol.Version39
// A Cookie can be used to verify the ownership of the client address. // A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie type Cookie = handshake.Cookie
...@@ -113,15 +116,25 @@ type StreamError interface { ...@@ -113,15 +116,25 @@ type StreamError interface {
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
type Session interface { type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error) AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached. // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// New streams always have the smallest possible stream ID. AcceptUniStream() (ReceiveStream, error)
// TODO: Enable testing for the special error // OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error) OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened. // OpenStreamSync opens a new bidirectional QUIC stream.
// It always picks the smallest possible stream ID. // It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error) OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address. // LocalAddr returns the local address.
LocalAddr() net.Addr LocalAddr() net.Addr
// RemoteAddr returns the address of the peer. // RemoteAddr returns the address of the peer.
...@@ -166,6 +179,17 @@ type Config struct { ...@@ -166,6 +179,17 @@ type Config struct {
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64 MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool KeepAlive bool
} }
......
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
...@@ -10,15 +10,13 @@ import ( ...@@ -10,15 +10,13 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete() SetHandshakeComplete()
// SendingAllowed says if a packet can be sent. // The SendMode determines if and what kind of packets can be sent.
// Sending packets might not be possible because: SendMode() SendMode
// * we're congestion limited
// * we're tracking the maximum number of sent packets
SendingAllowed() bool
// TimeUntilSend is the time when the next packet should be sent. // TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets. // It is used for pacing packets.
TimeUntilSend() time.Time TimeUntilSend() time.Time
...@@ -32,10 +30,10 @@ type SentPacketHandler interface { ...@@ -32,10 +30,10 @@ type SentPacketHandler interface {
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet) DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time GetAlarmTimeout() time.Time
OnAlarm() OnAlarm() error
} }
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
......
...@@ -8,28 +8,22 @@ import ( ...@@ -8,28 +8,22 @@ import (
) )
// A Packet is a packet // A Packet is a packet
// +gen linkedlist
type Packet struct { type Packet struct {
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame Frames []wire.Frame
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission // There are two reasons why a packet cannot be retransmitted:
func (p *Packet) GetFramesForRetransmission() []wire.Frame { // * it was already retransmitted
var fs []wire.Frame // * this packet is a retransmission, and we already received an ACK for the original packet
for _, frame := range p.Frames { canBeRetransmitted bool
switch frame.(type) { includedInBytesInFlight bool
case *wire.AckFrame: retransmittedAs []protocol.PacketNumber
continue isRetransmission bool // we need a separate bool here because 0 is a valid packet number
case *wire.StopWaitingFrame: retransmissionOf protocol.PacketNumber
continue
}
fs = append(fs, frame)
}
return fs
} }
// Generated by: main // This file was automatically generated by genny.
// TypeWriter: linkedlist // Any changes will be lost if this file is regenerated.
// Directive: +gen on Packet // see https://github.com/cheekybits/genny
package ackhandler package ackhandler
// List is a modification of http://golang.org/pkg/container/list/ // Linked list implementation from the Go standard library.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// PacketElement is an element of a linked list. // PacketElement is an element of a linked list.
type PacketElement struct { type PacketElement struct {
...@@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement { ...@@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
return nil return nil
} }
// PacketList represents a doubly linked list. // PacketList is a linked list of Packets.
// The zero value for PacketList is an empty list ready to use.
type PacketList struct { type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element len int // current list length excluding (this) sentinel element
...@@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() } ...@@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
// The complexity is O(1). // The complexity is O(1).
func (l *PacketList) Len() int { return l.len } func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement { func (l *PacketList) Front() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement { ...@@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
return l.root.next return l.root.next
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement { func (l *PacketList) Back() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement { ...@@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
return l.root.prev return l.root.prev
} }
// lazyInit lazily initializes a zero PacketList value. // lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() { func (l *PacketList) lazyInit() {
if l.root.next == nil { if l.root.next == nil {
l.Init() l.Init()
...@@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement { ...@@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
return e return e
} }
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at). // insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at) return l.insert(&PacketElement{Value: v}, at)
} }
...@@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement { ...@@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
// Remove removes e from l if e is an element of list l. // Remove removes e from l if e is an element of list l.
// It returns the element value e.Value. // It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet { func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l { if e.list == l {
// if e.list == l, l must have been initialized when e was inserted // if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketElement) and l.remove will crash // in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e) l.remove(e)
} }
return e.Value return e.Value
...@@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement { ...@@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e. // InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev) return l.insertValue(v, mark.prev)
} }
// InsertAfter inserts a new element e with value v immediately after mark and returns e. // InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark) return l.insertValue(v, mark)
} }
// MoveToFront moves element e to the front of list l. // MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) { func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e { if e.list != l || l.root.next == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root) l.insert(l.remove(e), &l.root)
} }
// MoveToBack moves element e to the back of list l. // MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) { func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e { if e.list != l || l.root.prev == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark. // MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) { func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) { ...@@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
} }
// MoveAfter moves element e to its new position after mark. // MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) { func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) { ...@@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
} }
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) { func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
...@@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) { ...@@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
} }
// PushFrontList inserts a copy of an other list at the front of list l. // PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) { func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
......
...@@ -3,7 +3,9 @@ package ackhandler ...@@ -3,7 +3,9 @@ package ackhandler
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
...@@ -15,6 +17,7 @@ type receivedPacketHandler struct { ...@@ -15,6 +17,7 @@ type receivedPacketHandler struct {
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
ackSendDelay time.Duration ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
...@@ -25,29 +28,54 @@ type receivedPacketHandler struct { ...@@ -25,29 +28,54 @@ type receivedPacketHandler struct {
version protocol.VersionNumber version protocol.VersionNumber
} }
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay, ackSendDelay: ackSendDelay,
rttStats: rttStats,
version: version, version: version,
} }
} }
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime h.largestObservedReceivedTime = rcvTime
} }
if packetNumber < h.ignoreBelow {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err return err
} }
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck) h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil return nil
} }
...@@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { ...@@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.packetHistory.DeleteBelow(p) h.packetHistory.DeleteBelow(p)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { // isMissing says if a packet was reported missing in the last ACK.
h.packetsReceivedSinceLastAck++ func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil {
if shouldInstigateAck { return false
h.retransmittablePacketsReceivedSinceLastAck++
} }
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
}
// always ack the first packet func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil { if h.lastAck == nil {
h.ackQueued = true return false
} }
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK // maybeQueueAck queues an ACK, if necessary.
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() // It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { // in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true h.ackQueued = true
return
} }
// check if a new missing range above the previously was created // Send an ACK if this packet was reported missing in an ACK sent before.
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { // Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
h.ackQueued = true h.ackQueued = true
} }
if !h.ackQueued && shouldInstigateAck { if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck { h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true h.ackQueued = true
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
}
} else { } else {
if h.ackAlarm.IsZero() { // send an ACK every 2 retransmittable packets
h.ackAlarm = rcvTime.Add(h.ackSendDelay) if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
ackTime := rcvTime.Add(time.Duration(ackDelay))
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
} }
} }
} }
...@@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { ...@@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
h.ackQueued = false h.ackQueued = false
h.packetsReceivedSinceLastAck = 0 h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0 h.retransmittablePacketsReceivedSinceLastAck = 0
return ack return ack
} }
......
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendAny packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
...@@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() { ...@@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() {
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) { func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled c.slowStartLargeReduction = enabled
} }
// RetransmissionDelay gives the time to retransmission
func (c *cubicSender) RetransmissionDelay() time.Duration {
if c.rttStats.SmoothedRTT() == 0 {
return 0
}
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
}
...@@ -17,7 +17,6 @@ type SendAlgorithm interface { ...@@ -17,7 +17,6 @@ type SendAlgorithm interface {
SetNumEmulatedConnections(n int) SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool) OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration() OnConnectionMigration()
RetransmissionDelay() time.Duration
// Experiments // Experiments
SetSlowStartLargeReduction(enabled bool) SetSlowStartLargeReduction(enabled bool)
......
...@@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) { ...@@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
// UpdateRTT updates the RTT based on a new sample. // UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 { if sendDelta == utils.InfDuration || sendDelta <= 0 {
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
return return
} }
......
...@@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) { ...@@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
return cert.Certificate[0], nil return cert.Certificate[0], nil
} }
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
c := cc.config conf := c.config
c, err := maybeGetConfigForClient(c, sni) conf, err := maybeGetConfigForClient(conf, sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// The rest of this function is mostly copied from crypto/tls.getCertificate // The rest of this function is mostly copied from crypto/tls.getCertificate
if c.GetCertificate != nil { if conf.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil { if cert != nil || err != nil {
return cert, err return cert, err
} }
} }
if len(c.Certificates) == 0 { if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate return nil, errNoMatchingCertificate
} }
if len(c.Certificates) == 1 || c.NameToCertificate == nil { if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work. // There's only one choice, so no point doing any work.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
name := strings.ToLower(sni) name := strings.ToLower(sni)
...@@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
name = name[:len(name)-1] name = name[:len(name)-1]
} }
if cert, ok := c.NameToCertificate[name]; ok { if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil return cert, nil
} }
...@@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
for i := range labels { for i := range labels {
labels[i] = "*" labels[i] = "*"
candidate := strings.Join(labels, ".") candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok { if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil return cert, nil
} }
} }
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
......
...@@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) { ...@@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
if _, err := rand.Read(c.secret[:]); err != nil { if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key") return nil, errors.New("Curve25519: could not create private key")
} }
// See https://cr.yp.to/ecdh.html
c.secret[0] &= 248
c.secret[31] &= 127
c.secret[31] |= 64
curve25519.ScalarBaseMult(&c.public, &c.secret) curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil return c, nil
} }
......
package crypto package crypto
import ( import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
const ( const (
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" serverExporterLabel = "EXPORTER-QUIC server 1rtt"
) )
// A TLSExporter gets the negotiated ciphersuite and computes exporter // A TLSExporter gets the negotiated ciphersuite and computes exporter
...@@ -16,6 +19,14 @@ type TLSExporter interface { ...@@ -16,6 +19,14 @@ type TLSExporter interface {
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
} }
func qhkdfExpand(secret []byte, label string, length int) []byte {
qlabel := make([]byte, 2+1+5+len(label))
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance // DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string var myLabel, otherLabel string
...@@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) ...@@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil return key, iv, nil
} }
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID) clientSecret, serverSecret := computeSecrets(connectionID)
...@@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec ...@@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8) connID := make([]byte, 8)
binary.BigEndian.PutUint64(connID, uint64(connectionID)) binary.BigEndian.PutUint64(connID, uint64(connectionID))
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
return return
} }
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) key = qhkdfExpand(secret, "key", 16)
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) iv = qhkdfExpand(secret, "iv", 12)
return return
} }
...@@ -25,6 +25,8 @@ type baseFlowController struct { ...@@ -25,6 +25,8 @@ type baseFlowController struct {
epochStartTime time.Time epochStartTime time.Time
epochStartOffset protocol.ByteCount epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
logger utils.Logger
} }
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
......
...@@ -22,6 +22,7 @@ func NewConnectionFlowController( ...@@ -22,6 +22,7 @@ func NewConnectionFlowController(
receiveWindow protocol.ByteCount, receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController { ) ConnectionFlowController {
return &connectionFlowController{ return &connectionFlowController{
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
...@@ -29,6 +30,7 @@ func NewConnectionFlowController( ...@@ -29,6 +30,7 @@ func NewConnectionFlowController(
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
}, },
} }
} }
...@@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { ...@@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize { if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
} }
c.mutex.Unlock() c.mutex.Unlock()
return offset return offset
......
...@@ -31,6 +31,7 @@ func NewStreamFlowController( ...@@ -31,6 +31,7 @@ func NewStreamFlowController(
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount, initialSendWindow protocol.ByteCount,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController { ) StreamFlowController {
return &streamFlowController{ return &streamFlowController{
streamID: streamID, streamID: streamID,
...@@ -42,6 +43,7 @@ func NewStreamFlowController( ...@@ -42,6 +43,7 @@ func NewStreamFlowController(
receiveWindowSize: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow, sendWindow: initialSendWindow,
logger: logger,
}, },
} }
} }
...@@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { ...@@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection { if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
} }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment