58 lines
1.4 KiB
Go
58 lines
1.4 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"net/http"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/favoritter/internal/store"
|
||
|
|
)
|
||
|
|
|
||
|
|
// SessionCookieName is the name of the HTTP cookie whose value SessionLoader
// passes to the session store for validation, and which ClearSessionCookie
// expires on the client.
const SessionCookieName = "session"
|
||
|
|
|
||
|
|
// ClearSessionCookie sets an expired session cookie to remove it from the client.
|
||
|
|
func ClearSessionCookie(w http.ResponseWriter) {
|
||
|
|
http.SetCookie(w, &http.Cookie{
|
||
|
|
Name: SessionCookieName,
|
||
|
|
Value: "",
|
||
|
|
Path: "/",
|
||
|
|
MaxAge: -1,
|
||
|
|
HttpOnly: true,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// SessionLoader loads the user from the session cookie on every request.
|
||
|
|
// If the session is valid, the user is attached to the request context.
|
||
|
|
func SessionLoader(sessions *store.SessionStore, users *store.UserStore) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
cookie, err := r.Cookie(SessionCookieName)
|
||
|
|
if err != nil {
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
session, err := sessions.Validate(cookie.Value)
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, store.ErrSessionNotFound) {
|
||
|
|
ClearSessionCookie(w)
|
||
|
|
}
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
user, err := users.GetByID(session.UserID)
|
||
|
|
if err != nil || user.Disabled {
|
||
|
|
sessions.Delete(cookie.Value)
|
||
|
|
ClearSessionCookie(w)
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx := context.WithValue(r.Context(), userKey, user)
|
||
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|