// SPDX-License-Identifier: AGPL-3.0-or-later package middleware import ( "context" "errors" "net/http" "strings" "kode.naiv.no/olemd/favoritter/internal/store" ) const SessionCookieName = "session" // ClearSessionCookie sets an expired session cookie to remove it from the client. func ClearSessionCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{ Name: SessionCookieName, Value: "", Path: "/", MaxAge: -1, HttpOnly: true, }) } // SessionLoader loads the user from the session cookie on every request. // If the session is valid, the user is attached to the request context. func SessionLoader(sessions *store.SessionStore, users *store.UserStore) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip session lookup for static assets and uploads — they // never use the user context and this avoids 2 DB queries // per asset per page load. if strings.HasPrefix(r.URL.Path, "/static/") || strings.HasPrefix(r.URL.Path, "/uploads/") { next.ServeHTTP(w, r) return } cookie, err := r.Cookie(SessionCookieName) if err != nil { next.ServeHTTP(w, r) return } session, err := sessions.Validate(cookie.Value) if err != nil { if errors.Is(err, store.ErrSessionNotFound) { ClearSessionCookie(w) } next.ServeHTTP(w, r) return } user, err := users.GetByID(session.UserID) if err != nil || user.Disabled { sessions.Delete(cookie.Value) ClearSessionCookie(w) next.ServeHTTP(w, r) return } ctx := context.WithValue(r.Context(), userKey, user) next.ServeHTTP(w, r.WithContext(ctx)) }) } }