93 lines
2.4 KiB
Go
93 lines
2.4 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"encoding/hex"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/favoritter/internal/config"
|
||
|
|
)
|
||
|
|
|
||
|
|
const (
|
||
|
|
csrfCookieName = "csrf_token"
|
||
|
|
csrfFormField = "csrf_token"
|
||
|
|
csrfHeaderName = "X-CSRF-Token"
|
||
|
|
)
|
||
|
|
|
||
|
|
// CSRFProtection implements double-submit cookie pattern for CSRF prevention.
|
||
|
|
// A token is stored in a cookie and must also be submitted in a form field
|
||
|
|
// or header on state-changing requests (POST, PUT, DELETE, PATCH).
|
||
|
|
func CSRFProtection(cfg *config.Config) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
// Read or generate the CSRF token.
|
||
|
|
token := ""
|
||
|
|
if cookie, err := r.Cookie(csrfCookieName); err == nil {
|
||
|
|
token = cookie.Value
|
||
|
|
}
|
||
|
|
if token == "" {
|
||
|
|
token = generateCSRFToken()
|
||
|
|
secure := IsSecureRequest(r, cfg)
|
||
|
|
http.SetCookie(w, &http.Cookie{
|
||
|
|
Name: csrfCookieName,
|
||
|
|
Value: token,
|
||
|
|
Path: "/",
|
||
|
|
HttpOnly: false, // JS needs to read it for HTMX hx-headers
|
||
|
|
Secure: secure,
|
||
|
|
SameSite: http.SameSiteLaxMode,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// Attach token to context for templates.
|
||
|
|
ctx := context.WithValue(r.Context(), csrfTokenKey, token)
|
||
|
|
r = r.WithContext(ctx)
|
||
|
|
|
||
|
|
// Validate on state-changing methods.
|
||
|
|
if isStateChangingMethod(r.Method) {
|
||
|
|
// Skip CSRF check for API routes that use Bearer auth (future).
|
||
|
|
if !strings.HasPrefix(r.URL.Path, "/api/") {
|
||
|
|
submitted := r.FormValue(csrfFormField)
|
||
|
|
if submitted == "" {
|
||
|
|
submitted = r.Header.Get(csrfHeaderName)
|
||
|
|
}
|
||
|
|
if submitted != token {
|
||
|
|
http.Error(w, "CSRF token mismatch", http.StatusForbidden)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func isStateChangingMethod(method string) bool {
|
||
|
|
switch method {
|
||
|
|
case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch:
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
func generateCSRFToken() string {
|
||
|
|
b := make([]byte, 32)
|
||
|
|
rand.Read(b)
|
||
|
|
return hex.EncodeToString(b)
|
||
|
|
}
|
||
|
|
|
||
|
|
// IsSecureRequest determines if the original client request used HTTPS,
|
||
|
|
// checking X-Forwarded-Proto from trusted proxies.
|
||
|
|
func IsSecureRequest(r *http.Request, cfg *config.Config) bool {
|
||
|
|
if cfg.ExternalURL != "" {
|
||
|
|
return strings.HasPrefix(cfg.ExternalURL, "https://")
|
||
|
|
}
|
||
|
|
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||
|
|
return proto == "https"
|
||
|
|
}
|
||
|
|
return r.TLS != nil
|
||
|
|
}
|