61 lines
1.6 KiB
Go
61 lines
1.6 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"context"
	"net/http"
	"strings"

	"kode.naiv.no/olemd/favoritter/internal/model"
)
|
||
|
|
|
||
|
|
// contextKey is the type of every key this package stores in a
// context.Context. Because the type is unexported, no other package can
// name it, so these keys cannot collide with keys set elsewhere even when
// the underlying strings are equal.
type contextKey string

// Keys under which per-request values are stored in the request context.
const (
	// userKey holds the authenticated *model.User.
	userKey = contextKey("user")
	// csrfTokenKey holds the CSRF token as a string.
	csrfTokenKey = contextKey("csrf_token")
	// realIPKey holds the client's real IP address as a string.
	realIPKey = contextKey("real_ip")
)
|
||
|
|
|
||
|
|
// UserFromContext returns the authenticated user from the request context, or nil.
|
||
|
|
func UserFromContext(ctx context.Context) *model.User {
|
||
|
|
u, _ := ctx.Value(userKey).(*model.User)
|
||
|
|
return u
|
||
|
|
}
|
||
|
|
|
||
|
|
// CSRFTokenFromContext returns the CSRF token from the request context.
|
||
|
|
func CSRFTokenFromContext(ctx context.Context) string {
|
||
|
|
s, _ := ctx.Value(csrfTokenKey).(string)
|
||
|
|
return s
|
||
|
|
}
|
||
|
|
|
||
|
|
// RealIPFromContext returns the real client IP from the request context.
|
||
|
|
func RealIPFromContext(ctx context.Context) string {
|
||
|
|
s, _ := ctx.Value(realIPKey).(string)
|
||
|
|
return s
|
||
|
|
}
|
||
|
|
|
||
|
|
// RequireLogin redirects to the login page if no user is authenticated.
|
||
|
|
func RequireLogin(basePath string) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
if UserFromContext(r.Context()) == nil {
|
||
|
|
http.Redirect(w, r, basePath+"/login", http.StatusSeeOther)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// RequireAdmin returns 403 if the user is not an admin.
|
||
|
|
func RequireAdmin(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
user := UserFromContext(r.Context())
|
||
|
|
if user == nil || !user.IsAdmin() {
|
||
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|