// SPDX-License-Identifier: AGPL-3.0-or-later package middleware import ( "net/http" "net/http/httptest" "strings" "testing" "kode.naiv.no/olemd/favoritter/internal/config" ) func testCSRFHandler(cfg *config.Config) http.Handler { inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Echo the CSRF token from context so tests can verify it. token := CSRFTokenFromContext(r.Context()) w.Write([]byte(token)) }) return CSRFProtection(cfg)(inner) } func TestCSRFTokenSetInCookie(t *testing.T) { handler := testCSRFHandler(&config.Config{}) req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("GET: got %d, want 200", rr.Code) } // Should set a csrf_token cookie. var found bool for _, c := range rr.Result().Cookies() { if c.Name == "csrf_token" { found = true if c.Value == "" { t.Error("csrf_token cookie is empty") } if c.HttpOnly { t.Error("csrf_token cookie should not be HttpOnly (JS needs to read it)") } } } if !found { t.Error("csrf_token cookie not set on first request") } // The context token should match the cookie. body := rr.Body.String() if body == "" { t.Error("CSRF token not set in context") } } func TestCSRFValidTokenAccepted(t *testing.T) { handler := testCSRFHandler(&config.Config{}) // First GET to obtain a token. getReq := httptest.NewRequest("GET", "/", nil) getRR := httptest.NewRecorder() handler.ServeHTTP(getRR, getReq) var token string for _, c := range getRR.Result().Cookies() { if c.Name == "csrf_token" { token = c.Value } } if token == "" { t.Fatal("no csrf_token cookie from GET") } // POST with matching cookie + form field. form := "csrf_token=" + token req := httptest.NewRequest("POST", "/submit", strings.NewReader(form)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("valid CSRF POST: got %d, want 200", rr.Code) } } func TestCSRFMismatchRejected(t *testing.T) { handler := testCSRFHandler(&config.Config{}) form := "csrf_token=wrong-token" req := httptest.NewRequest("POST", "/submit", strings.NewReader(form)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "real-token"}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusForbidden { t.Errorf("mismatched CSRF: got %d, want 403", rr.Code) } } func TestCSRFMissingTokenRejected(t *testing.T) { handler := testCSRFHandler(&config.Config{}) // POST with cookie but no form field or header. req := httptest.NewRequest("POST", "/submit", nil) req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusForbidden { t.Errorf("missing CSRF form field: got %d, want 403", rr.Code) } } func TestCSRFHeaderFallback(t *testing.T) { handler := testCSRFHandler(&config.Config{}) token := "valid-header-token" req := httptest.NewRequest("POST", "/submit", nil) req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token}) req.Header.Set("X-CSRF-Token", token) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("CSRF via header: got %d, want 200", rr.Code) } } func TestCSRFSkippedForAPI(t *testing.T) { handler := testCSRFHandler(&config.Config{}) // POST to /api/ path — should skip CSRF validation. req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(`{}`)) req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"}) // Intentionally no CSRF form field or header. rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("API route CSRF skip: got %d, want 200", rr.Code) } } func TestCSRFSafeMethodsPassThrough(t *testing.T) { handler := testCSRFHandler(&config.Config{}) for _, method := range []string{"GET", "HEAD", "OPTIONS"} { req := httptest.NewRequest(method, "/page", nil) // No CSRF cookie or token at all. rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("%s without CSRF: got %d, want 200", method, rr.Code) } } } func TestCSRFExistingCookieReused(t *testing.T) { handler := testCSRFHandler(&config.Config{}) // Send a request with an existing csrf_token cookie. existingToken := "pre-existing-token-value" req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{Name: "csrf_token", Value: existingToken}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) // The context token should be the existing one. body := rr.Body.String() if body != existingToken { t.Errorf("context token = %q, want existing %q", body, existingToken) } // Should NOT set a new cookie (existing one is reused). for _, c := range rr.Result().Cookies() { if c.Name == "csrf_token" { t.Error("should not set new csrf_token cookie when one already exists") } } }