Skip to content

Commit

Permalink
Merge pull request #1 from ngrok-oss/add-trusted-origin-predicate-func
Browse files Browse the repository at this point in the history
adds TrustedOriginPredicateFunc option, tests
  • Loading branch information
cody-dot-js committed Aug 14, 2024
2 parents a009743 + 9baf069 commit 0b2a15d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 15 deletions.
29 changes: 17 additions & 12 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,15 @@ type options struct {
Path string
// Note that the function and field names match the case of the associated
// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
HttpOnly bool
Secure bool
SameSite SameSiteMode
RequestHeader string
FieldName string
ErrorHandler http.Handler
CookieName string
TrustedOrigins []string
HttpOnly bool
Secure bool
SameSite SameSiteMode
RequestHeader string
FieldName string
ErrorHandler http.Handler
CookieName string
TrustedOrigins []string
TrustedOriginPredicateFunc func(referer string) bool
}

// Protect is HTTP middleware that provides Cross-Site Request Forgery
Expand Down Expand Up @@ -258,10 +259,14 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
valid := sameOrigin(r.URL, referer)

if !valid {
for _, trustedOrigin := range cs.opts.TrustedOrigins {
if referer.Host == trustedOrigin {
valid = true
break
if cs.opts.TrustedOriginPredicateFunc != nil {
valid = cs.opts.TrustedOriginPredicateFunc(referer.Host)
} else {
for _, trustedOrigin := range cs.opts.TrustedOrigins {
if referer.Host == trustedOrigin {
valid = true
break
}
}
}
}
Expand Down
76 changes: 76 additions & 0 deletions csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,82 @@ func TestTrustedReferer(t *testing.T) {
}
}

// TestTrustedOriginPredicateFunc checks that HTTPS requests with a Referer that does not
// match the request URL correctly but is a trusted origin pass CSRF validation.
func TestTrustedOriginPredicateFunc(t *testing.T) {

testTable := []struct {
predicate func(referer string) bool
shouldPass bool
}{
{func(referer string) bool {
return referer == "golang.org"
}, true},
{func(referer string) bool {
return referer == "api.example.com" || referer == "golang.org"
}, true},
{func(referer string) bool {
return referer == "http://golang.org"
}, false},
{func(referer string) bool {
return referer == "https://golang.org"
}, false},
{func(referer string) bool {
return referer == "http://example.com"
}, false},
{func(referer string) bool {
return referer == "example.com"
}, false},
}

for _, item := range testTable {
s := http.NewServeMux()

p := Protect(testKey, TrustedOriginPredicateFunc(item.predicate))(s)

var token string
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
token = Token(r)
}))

// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
p.ServeHTTP(rr, r)

// POST the token back in the header.
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

setCookie(rr, r)
r.Header.Set("X-CSRF-Token", token)

// Set a non-matching Referer header.
r.Header.Set("Referer", "http://golang.org/")

rr = httptest.NewRecorder()
p.ServeHTTP(rr, r)

if item.shouldPass {
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
rr.Code, http.StatusOK)
}
} else {
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v",
rr.Code, http.StatusForbidden)
}
}
}
}

// Requests with a valid Referer should pass.
func TestWithReferer(t *testing.T) {
s := http.NewServeMux()
Expand Down
5 changes: 2 additions & 3 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ field.
gorilla/csrf is easy to use: add the middleware to individual handlers with
the below:
CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
http.HandlerFunc("/route", CSRF(YourHandler))
CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
http.HandlerFunc("/route", CSRF(YourHandler))
... and then collect the token with `csrf.Token(r)` before passing it to the
template, JSON body or HTTP header (you pick!). gorilla/csrf inspects the form body
Expand Down Expand Up @@ -171,6 +171,5 @@ important.
and the one-time-pad used for masking them.
This library does not seek to be adventurous.
*/
package csrf
15 changes: 15 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ func TrustedOrigins(origins []string) Option {
}
}

// TrustedOriginPredicateFunc configures a predicate function that can be used
// to determine if a given Referer is trusted.
// Like TrustedOrigins, this will allow cross-domain CSRF use-cases - e.g. where
// the front-end is served from a different domain than the API server - to
// correctly pass a CSRF check.
// However, this function allows for more complex logic to be applied to determine
// if a Referer is trusted than strict equality string matching.
//
// You should only pass origins you own or have full control over.
func TrustedOriginPredicateFunc(f func(referer string) bool) Option {
return func(cs *csrf) {
cs.opts.TrustedOriginPredicateFunc = f
}
}

// setStore sets the store used by the CSRF middleware.
// Note: this is private (for now) to allow for internal API changes.
func setStore(s store) Option {
Expand Down

0 comments on commit 0b2a15d

Please sign in to comment.