185 lines
5.1 KiB
Go
185 lines
5.1 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/favoritter/internal/config"
|
||
|
|
)
|
||
|
|
|
||
|
|
func testCSRFHandler(cfg *config.Config) http.Handler {
|
||
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
// Echo the CSRF token from context so tests can verify it.
|
||
|
|
token := CSRFTokenFromContext(r.Context())
|
||
|
|
w.Write([]byte(token))
|
||
|
|
})
|
||
|
|
return CSRFProtection(cfg)(inner)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFTokenSetInCookie(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Fatalf("GET: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Should set a csrf_token cookie.
|
||
|
|
var found bool
|
||
|
|
for _, c := range rr.Result().Cookies() {
|
||
|
|
if c.Name == "csrf_token" {
|
||
|
|
found = true
|
||
|
|
if c.Value == "" {
|
||
|
|
t.Error("csrf_token cookie is empty")
|
||
|
|
}
|
||
|
|
if c.HttpOnly {
|
||
|
|
t.Error("csrf_token cookie should not be HttpOnly (JS needs to read it)")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if !found {
|
||
|
|
t.Error("csrf_token cookie not set on first request")
|
||
|
|
}
|
||
|
|
|
||
|
|
// The context token should match the cookie.
|
||
|
|
body := rr.Body.String()
|
||
|
|
if body == "" {
|
||
|
|
t.Error("CSRF token not set in context")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFValidTokenAccepted(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
// First GET to obtain a token.
|
||
|
|
getReq := httptest.NewRequest("GET", "/", nil)
|
||
|
|
getRR := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(getRR, getReq)
|
||
|
|
|
||
|
|
var token string
|
||
|
|
for _, c := range getRR.Result().Cookies() {
|
||
|
|
if c.Name == "csrf_token" {
|
||
|
|
token = c.Value
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if token == "" {
|
||
|
|
t.Fatal("no csrf_token cookie from GET")
|
||
|
|
}
|
||
|
|
|
||
|
|
// POST with matching cookie + form field.
|
||
|
|
form := "csrf_token=" + token
|
||
|
|
req := httptest.NewRequest("POST", "/submit", strings.NewReader(form))
|
||
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token})
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("valid CSRF POST: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFMismatchRejected(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
form := "csrf_token=wrong-token"
|
||
|
|
req := httptest.NewRequest("POST", "/submit", strings.NewReader(form))
|
||
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "real-token"})
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusForbidden {
|
||
|
|
t.Errorf("mismatched CSRF: got %d, want 403", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFMissingTokenRejected(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
// POST with cookie but no form field or header.
|
||
|
|
req := httptest.NewRequest("POST", "/submit", nil)
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"})
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusForbidden {
|
||
|
|
t.Errorf("missing CSRF form field: got %d, want 403", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFHeaderFallback(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
token := "valid-header-token"
|
||
|
|
req := httptest.NewRequest("POST", "/submit", nil)
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token})
|
||
|
|
req.Header.Set("X-CSRF-Token", token)
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("CSRF via header: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFSkippedForAPI(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
// POST to /api/ path — should skip CSRF validation.
|
||
|
|
req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(`{}`))
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"})
|
||
|
|
// Intentionally no CSRF form field or header.
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("API route CSRF skip: got %d, want 200", rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFSafeMethodsPassThrough(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
for _, method := range []string{"GET", "HEAD", "OPTIONS"} {
|
||
|
|
req := httptest.NewRequest(method, "/page", nil)
|
||
|
|
// No CSRF cookie or token at all.
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
if rr.Code != http.StatusOK {
|
||
|
|
t.Errorf("%s without CSRF: got %d, want 200", method, rr.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCSRFExistingCookieReused(t *testing.T) {
|
||
|
|
handler := testCSRFHandler(&config.Config{})
|
||
|
|
|
||
|
|
// Send a request with an existing csrf_token cookie.
|
||
|
|
existingToken := "pre-existing-token-value"
|
||
|
|
req := httptest.NewRequest("GET", "/", nil)
|
||
|
|
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: existingToken})
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
// The context token should be the existing one.
|
||
|
|
body := rr.Body.String()
|
||
|
|
if body != existingToken {
|
||
|
|
t.Errorf("context token = %q, want existing %q", body, existingToken)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Should NOT set a new cookie (existing one is reused).
|
||
|
|
for _, c := range rr.Result().Cookies() {
|
||
|
|
if c.Name == "csrf_token" {
|
||
|
|
t.Error("should not set new csrf_token cookie when one already exists")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|