Commit cf2808ae authored by Matt Holt's avatar Matt Holt

Merge pull request #75 from abiosoft/master

Rewrite: Support for Regular Expressions.
parents 1076daa8 c382c885
...@@ -18,23 +18,62 @@ func Rewrite(c *Controller) (middleware.Middleware, error) { ...@@ -18,23 +18,62 @@ func Rewrite(c *Controller) (middleware.Middleware, error) {
} }
func rewriteParse(c *Controller) ([]rewrite.Rule, error) { func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
var rewrites []rewrite.Rule var simpleRules []rewrite.Rule
var regexpRules []rewrite.Rule
for c.Next() { for c.Next() {
var rule rewrite.Rule var rule rewrite.Rule
var err error
var base = "/"
var pattern, to string
var ext []string
if !c.NextArg() { args := c.RemainingArgs()
return rewrites, c.ArgErr()
}
rule.From = c.Val()
if !c.NextArg() { switch len(args) {
return rewrites, c.ArgErr() case 2:
rule = rewrite.NewSimpleRule(args[0], args[1])
simpleRules = append(simpleRules, rule)
case 1:
base = args[0]
fallthrough
case 0:
for c.NextBlock() {
switch c.Val() {
case "r", "regexp":
if !c.NextArg() {
return nil, c.ArgErr()
}
pattern = c.Val()
case "to":
if !c.NextArg() {
return nil, c.ArgErr()
}
to = c.Val()
case "ext":
args1 := c.RemainingArgs()
if len(args1) == 0 {
return nil, c.ArgErr()
}
ext = args1
default:
return nil, c.ArgErr()
}
}
// ensure pattern and to are specified
if pattern == "" || to == "" {
return nil, c.ArgErr()
}
if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil {
return nil, err
}
regexpRules = append(regexpRules, rule)
default:
return nil, c.ArgErr()
} }
rule.To = c.Val()
rewrites = append(rewrites, rule)
} }
return rewrites, nil // put simple rules in front to avoid regexp computation for them
return append(simpleRules, regexpRules...), nil
} }
...@@ -3,7 +3,9 @@ package setup ...@@ -3,7 +3,9 @@ package setup
import ( import (
"testing" "testing"
"fmt"
"github.com/mholt/caddy/middleware/rewrite" "github.com/mholt/caddy/middleware/rewrite"
"regexp"
) )
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
...@@ -33,27 +35,27 @@ func TestRewrite(t *testing.T) { ...@@ -33,27 +35,27 @@ func TestRewrite(t *testing.T) {
} }
func TestRewriteParse(t *testing.T) { func TestRewriteParse(t *testing.T) {
tests := []struct { simpleTests := []struct {
input string input string
shouldErr bool shouldErr bool
expected []rewrite.Rule expected []rewrite.Rule
}{ }{
{`rewrite /from /to`, false, []rewrite.Rule{ {`rewrite /from /to`, false, []rewrite.Rule{
{From: "/from", To: "/to"}, rewrite.SimpleRule{"/from", "/to"},
}}, }},
{`rewrite /from /to {`rewrite /from /to
rewrite a b`, false, []rewrite.Rule{ rewrite a b`, false, []rewrite.Rule{
{From: "/from", To: "/to"}, rewrite.SimpleRule{"/from", "/to"},
{From: "a", To: "b"}, rewrite.SimpleRule{"a", "b"},
}}, }},
{`rewrite a`, true, []rewrite.Rule{}}, {`rewrite a`, true, []rewrite.Rule{}},
{`rewrite`, true, []rewrite.Rule{}}, {`rewrite`, true, []rewrite.Rule{}},
{`rewrite a b c`, true, []rewrite.Rule{ {`rewrite a b c`, true, []rewrite.Rule{
{From: "a", To: "b"}, rewrite.SimpleRule{"a", "b"},
}}, }},
} }
for i, test := range tests { for i, test := range simpleTests {
c := newTestController(test.input) c := newTestController(test.input)
actual, err := rewriteParse(c) actual, err := rewriteParse(c)
...@@ -61,6 +63,8 @@ func TestRewriteParse(t *testing.T) { ...@@ -61,6 +63,8 @@ func TestRewriteParse(t *testing.T) {
t.Errorf("Test %d didn't error, but it should have", i) t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr { } else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
} else if err != nil && test.shouldErr {
continue
} }
if len(actual) != len(test.expected) { if len(actual) != len(test.expected) {
...@@ -68,8 +72,9 @@ func TestRewriteParse(t *testing.T) { ...@@ -68,8 +72,9 @@ func TestRewriteParse(t *testing.T) {
i, len(test.expected), len(actual)) i, len(test.expected), len(actual))
} }
for j, expectedRule := range test.expected { for j, e := range test.expected {
actualRule := actual[j] actualRule := actual[j].(rewrite.SimpleRule)
expectedRule := e.(rewrite.SimpleRule)
if actualRule.From != expectedRule.From { if actualRule.From != expectedRule.From {
t.Errorf("Test %d, rule %d: Expected From=%s, got %s", t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
...@@ -82,4 +87,98 @@ func TestRewriteParse(t *testing.T) { ...@@ -82,4 +87,98 @@ func TestRewriteParse(t *testing.T) {
} }
} }
} }
regexpTests := []struct {
input string
shouldErr bool
expected []rewrite.Rule
}{
{`rewrite {
r .*
to /to
}`, false, []rewrite.Rule{
&rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile(".*")},
}},
{`rewrite {
regexp .*
to /to
ext / html txt
}`, false, []rewrite.Rule{
&rewrite.RegexpRule{"/", "/to", []string{"/", "html", "txt"}, regexp.MustCompile(".*")},
}},
{`rewrite /path {
r rr
to /dest
}
rewrite / {
regexp [a-z]+
to /to
}
`, false, []rewrite.Rule{
&rewrite.RegexpRule{"/path", "/dest", nil, regexp.MustCompile("rr")},
&rewrite.RegexpRule{"/", "/to", nil, regexp.MustCompile("[a-z]+")},
}},
{`rewrite {
to /to
}`, true, []rewrite.Rule{
&rewrite.RegexpRule{},
}},
{`rewrite {
r .*
}`, true, []rewrite.Rule{
&rewrite.RegexpRule{},
}},
{`rewrite {
}`, true, []rewrite.Rule{
&rewrite.RegexpRule{},
}},
{`rewrite /`, true, []rewrite.Rule{
&rewrite.RegexpRule{},
}},
}
for i, test := range regexpTests {
c := newTestController(test.input)
actual, err := rewriteParse(c)
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
} else if err != nil && test.shouldErr {
continue
}
if len(actual) != len(test.expected) {
t.Fatalf("Test %d expected %d rules, but got %d",
i, len(test.expected), len(actual))
}
for j, e := range test.expected {
actualRule := actual[j].(*rewrite.RegexpRule)
expectedRule := e.(*rewrite.RegexpRule)
if actualRule.Base != expectedRule.Base {
t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
i, j, expectedRule.Base, actualRule.Base)
}
if actualRule.To != expectedRule.To {
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
i, j, expectedRule.To, actualRule.To)
}
if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) {
t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v",
i, j, expectedRule.To, actualRule.To)
}
if actualRule.String() != expectedRule.String() {
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
i, j, expectedRule.String(), actualRule.String())
}
}
}
} }
...@@ -5,7 +5,13 @@ package rewrite ...@@ -5,7 +5,13 @@ package rewrite
import ( import (
"net/http" "net/http"
"fmt"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"net/url"
"path"
"path/filepath"
"regexp"
"strings"
) )
// Rewrite is middleware to rewrite request locations internally before being handled. // Rewrite is middleware to rewrite request locations internally before being handled.
...@@ -17,15 +23,171 @@ type Rewrite struct { ...@@ -17,15 +23,171 @@ type Rewrite struct {
// ServeHTTP implements the middleware.Handler interface. // ServeHTTP implements the middleware.Handler interface.
func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rw.Rules { for _, rule := range rw.Rules {
if r.URL.Path == rule.From { if ok := rule.Rewrite(r); ok {
r.URL.Path = rule.To
break break
} }
} }
return rw.Next.ServeHTTP(w, r) return rw.Next.ServeHTTP(w, r)
} }
// A Rule describes an internal location rewrite rule. // Rule describes an internal location rewrite rule.
type Rule struct { type Rule interface {
// Rewrite rewrites the internal location of the current request.
Rewrite(*http.Request) bool
}
// SimpleRule is a simple rewrite rule.
type SimpleRule struct {
From, To string From, To string
} }
// NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule {
return SimpleRule{from, to}
}
// Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(r *http.Request) bool {
if s.From == r.URL.Path {
r.URL.Path = s.To
return true
}
return false
}
// RegexpRule is a rewrite rule based on a regular expression
type RegexpRule struct {
// Path base. Request to this path and subpaths will be rewritten
Base string
// Path to rewrite to
To string
// Extensions to filter by
Exts []string
*regexp.Regexp
}
// NewRegexpRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid.
func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) {
r, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
// validate extensions
for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified
if v != "/" && v != "!/" {
return nil, fmt.Errorf("Invalid extension %v", v)
}
}
}
return &RegexpRule{
base,
to,
ext,
r,
}, nil
}
// regexpVars are variables that can be used for To (rewrite destination path).
var regexpVars []string = []string{
"{path}",
"{query}",
"{file}",
"{dir}",
"{frag}",
}
// Rewrite rewrites the internal location of the current request.
func (r *RegexpRule) Rewrite(req *http.Request) bool {
rPath := req.URL.Path
// validate base
if !middleware.Path(rPath).Matches(r.Base) {
return false
}
// validate extensions
if !r.matchExt(rPath) {
return false
}
// validate regexp
if !r.MatchString(rPath[len(r.Base):]) {
return false
}
to := r.To
// check variables
for _, v := range regexpVars {
if strings.Contains(r.To, v) {
switch v {
case "{path}":
to = strings.Replace(to, v, req.URL.Path[1:], -1)
case "{query}":
to = strings.Replace(to, v, req.URL.RawQuery, -1)
case "{frag}":
to = strings.Replace(to, v, req.URL.Fragment, -1)
case "{file}":
_, file := path.Split(req.URL.Path)
to = strings.Replace(to, v, file, -1)
case "{dir}":
dir, _ := path.Split(req.URL.Path)
to = path.Clean(strings.Replace(to, v, dir, -1))
}
}
}
// validate resulting path
url, err := url.Parse(to)
if err != nil {
return false
}
// perform rewrite
req.URL.Path = url.Path
if url.RawQuery != "" {
// overwrite query string if present
req.URL.RawQuery = url.RawQuery
}
return true
}
// matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise.
func (r *RegexpRule) matchExt(rPath string) bool {
f := filepath.Base(rPath)
ext := path.Ext(f)
if ext == "" {
ext = "/"
}
mustUse := false
for _, v := range r.Exts {
use := true
if v[0] == '!' {
use = false
v = v[1:]
}
if use {
mustUse = true
}
if ext == v {
return use
}
}
if mustUse {
return false
}
return true
}
...@@ -7,16 +7,41 @@ import ( ...@@ -7,16 +7,41 @@ import (
"testing" "testing"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"strings"
) )
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
rw := Rewrite{ rw := Rewrite{
Next: middleware.HandlerFunc(urlPrinter), Next: middleware.HandlerFunc(urlPrinter),
Rules: []Rule{ Rules: []Rule{
{From: "/from", To: "/to"}, NewSimpleRule("/from", "/to"),
{From: "/a", To: "/b"}, NewSimpleRule("/a", "/b"),
}, },
} }
regexpRules := [][]string{
[]string{"/reg/", ".*", "/to", ""},
[]string{"/r/", "[a-z]+", "/toaz", "!.html|"},
[]string{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
[]string{"/ab/", "ab", "/ab?{query}", ".txt|"},
[]string{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
[]string{"/abc/", "ab", "/abc/{file}", ".html|"},
[]string{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
[]string{"/abcde/", "ab", "/a#{frag}", ".html|"},
}
for _, regexpRule := range regexpRules {
var ext []string
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
ext = s[:len(s)-1]
}
rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext)
if err != nil {
t.Fatal(err)
}
rw.Rules = append(rw.Rules, rule)
}
tests := []struct { tests := []struct {
from string from string
expectedTo string expectedTo string
...@@ -29,6 +54,28 @@ func TestRewrite(t *testing.T) { ...@@ -29,6 +54,28 @@ func TestRewrite(t *testing.T) {
{"/asdf?foo=bar", "/asdf?foo=bar"}, {"/asdf?foo=bar", "/asdf?foo=bar"},
{"/foo#bar", "/foo#bar"}, {"/foo#bar", "/foo#bar"},
{"/a#foo", "/b#foo"}, {"/a#foo", "/b#foo"},
{"/reg/foo", "/to"},
{"/re", "/re"},
{"/r/", "/r/"},
{"/r/123", "/r/123"},
{"/r/a123", "/toaz"},
{"/r/abcz", "/toaz"},
{"/r/z", "/toaz"},
{"/r/z.html", "/r/z.html"},
{"/r/z.js", "/toaz"},
{"/url/asAB", "/to/url/asAB"},
{"/url/aBsAB", "/url/aBsAB"},
{"/url/a00sAB", "/to/url/a00sAB"},
{"/url/a0z0sAB", "/to/url/a0z0sAB"},
{"/ab/aa", "/ab/aa"},
{"/ab/ab", "/ab/ab"},
{"/ab/ab.txt", "/ab"},
{"/ab/ab.txt?name=name", "/ab?name=name"},
{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
{"/abc/ab.html", "/abc/ab.html"},
{"/abcd/abcd.html", "/a/abcd/abcd.html"},
{"/abcde/abcde.html", "/a"},
{"/abcde/abcde.html#1234", "/a#1234"},
} }
for i, test := range tests { for i, test := range tests {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment