// SPDX-License-Identifier: AGPL-3.0-or-later package middleware import ( "context" "crypto/rand" "encoding/hex" "net/http" "strings" "kode.naiv.no/olemd/favoritter/internal/config" ) const ( csrfCookieName = "csrf_token" csrfFormField = "csrf_token" csrfHeaderName = "X-CSRF-Token" ) // CSRFProtection implements double-submit cookie pattern for CSRF prevention. // A token is stored in a cookie and must also be submitted in a form field // or header on state-changing requests (POST, PUT, DELETE, PATCH). func CSRFProtection(cfg *config.Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Read or generate the CSRF token. token := "" if cookie, err := r.Cookie(csrfCookieName); err == nil { token = cookie.Value } if token == "" { token = generateCSRFToken() secure := IsSecureRequest(r, cfg) http.SetCookie(w, &http.Cookie{ Name: csrfCookieName, Value: token, Path: "/", HttpOnly: false, // JS needs to read it for HTMX hx-headers Secure: secure, SameSite: http.SameSiteLaxMode, }) } // Attach token to context for templates. ctx := context.WithValue(r.Context(), csrfTokenKey, token) r = r.WithContext(ctx) // Validate on state-changing methods. if isStateChangingMethod(r.Method) { // Skip CSRF check for API routes that use Bearer auth (future). if !strings.HasPrefix(r.URL.Path, "/api/") { submitted := r.FormValue(csrfFormField) if submitted == "" { submitted = r.Header.Get(csrfHeaderName) } if submitted != token { http.Error(w, "CSRF token mismatch", http.StatusForbidden) return } } } next.ServeHTTP(w, r) }) } } func isStateChangingMethod(method string) bool { switch method { case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch: return true } return false } func generateCSRFToken() string { b := make([]byte, 32) rand.Read(b) return hex.EncodeToString(b) } // IsSecureRequest determines if the original client request used HTTPS, // checking X-Forwarded-Proto from trusted proxies. func IsSecureRequest(r *http.Request, cfg *config.Config) bool { if cfg.ExternalURL != "" { return strings.HasPrefix(cfg.ExternalURL, "https://") } if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { return proto == "https" } return r.TLS != nil }