// SPDX-License-Identifier: AGPL-3.0-or-later package middleware import ( "context" "net/http" "net/http/httptest" "testing" "kode.naiv.no/olemd/favoritter/internal/database" "kode.naiv.no/olemd/favoritter/internal/model" "kode.naiv.no/olemd/favoritter/internal/store" ) // testStores creates in-memory stores for auth middleware tests. func testStores(t *testing.T) (*store.SessionStore, *store.UserStore) { t.Helper() db, err := database.Open(":memory:") if err != nil { t.Fatalf("open db: %v", err) } if err := database.Migrate(db); err != nil { t.Fatalf("migrate: %v", err) } t.Cleanup(func() { db.Close() }) store.Argon2Memory = 1024 store.Argon2Time = 1 return store.NewSessionStore(db), store.NewUserStore(db) } func TestSessionLoaderValidToken(t *testing.T) { sessions, users := testStores(t) user, err := users.Create("testuser", "pass123", "user") if err != nil { t.Fatalf("create user: %v", err) } token, err := sessions.Create(user.ID) if err != nil { t.Fatalf("create session: %v", err) } var ctxUser *model.User inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxUser = UserFromContext(r.Context()) w.WriteHeader(http.StatusOK) }) handler := SessionLoader(sessions, users)(inner) req := httptest.NewRequest("GET", "/faves", nil) req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if ctxUser == nil { t.Fatal("expected user in context, got nil") } if ctxUser.ID != user.ID { t.Errorf("context user ID = %d, want %d", ctxUser.ID, user.ID) } } func TestSessionLoaderInvalidToken(t *testing.T) { sessions, users := testStores(t) var ctxUser *model.User inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxUser = UserFromContext(r.Context()) w.WriteHeader(http.StatusOK) }) handler := SessionLoader(sessions, users)(inner) req := httptest.NewRequest("GET", "/faves", nil) req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "invalid-token"}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("invalid token: got %d, want 200", rr.Code) } if ctxUser != nil { t.Error("expected nil user for invalid token") } // Should clear the invalid session cookie. for _, c := range rr.Result().Cookies() { if c.Name == SessionCookieName && c.MaxAge == -1 { return // Cookie cleared, good. } } t.Error("expected session cookie to be cleared for invalid token") } func TestSessionLoaderNoCookie(t *testing.T) { sessions, users := testStores(t) var ctxUser *model.User inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxUser = UserFromContext(r.Context()) w.WriteHeader(http.StatusOK) }) handler := SessionLoader(sessions, users)(inner) req := httptest.NewRequest("GET", "/faves", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("no cookie: got %d, want 200", rr.Code) } if ctxUser != nil { t.Error("expected nil user when no cookie") } } func TestSessionLoaderSkipsStaticPaths(t *testing.T) { sessions, users := testStores(t) var handlerCalled bool inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerCalled = true w.WriteHeader(http.StatusOK) }) handler := SessionLoader(sessions, users)(inner) for _, path := range []string{"/static/css/style.css", "/uploads/image.jpg"} { handlerCalled = false req := httptest.NewRequest("GET", path, nil) req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "some-token"}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if !handlerCalled { t.Errorf("handler not called for %s", path) } if rr.Code != http.StatusOK { t.Errorf("%s: got %d, want 200", path, rr.Code) } } } func TestSessionLoaderDisabledUser(t *testing.T) { sessions, users := testStores(t) user, _ := users.Create("testuser", "pass123", "user") token, _ := sessions.Create(user.ID) users.SetDisabled(user.ID, true) var ctxUser *model.User inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctxUser = UserFromContext(r.Context()) w.WriteHeader(http.StatusOK) }) handler := SessionLoader(sessions, users)(inner) req := httptest.NewRequest("GET", "/faves", nil) req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if ctxUser != nil { t.Error("disabled user should not be in context") } // Session should be deleted and cookie cleared. for _, c := range rr.Result().Cookies() { if c.Name == SessionCookieName && c.MaxAge == -1 { return } } t.Error("expected session cookie to be cleared for disabled user") } func TestRequireAdminRejectsNonAdmin(t *testing.T) { handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Regular user in context. user := &model.User{ID: 1, Username: "regular", Role: "user"} ctx := context.WithValue(context.Background(), userKey, user) req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusForbidden { t.Errorf("non-admin: got %d, want 403", rr.Code) } } func TestRequireAdminAllowsAdmin(t *testing.T) { handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) user := &model.User{ID: 1, Username: "admin", Role: "admin"} ctx := context.WithValue(context.Background(), userKey, user) req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("admin: got %d, want 200", rr.Code) } } func TestRequireAdminNoUser(t *testing.T) { handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/admin", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusForbidden { t.Errorf("no user: got %d, want 403", rr.Code) } } func TestContextHelpers(t *testing.T) { // UserFromContext with user. user := &model.User{ID: 42, Username: "tester"} ctx := context.WithValue(context.Background(), userKey, user) got := UserFromContext(ctx) if got == nil || got.ID != 42 { t.Errorf("UserFromContext with user: got %v", got) } // UserFromContext without user. if UserFromContext(context.Background()) != nil { t.Error("UserFromContext should return nil for empty context") } // CSRFTokenFromContext. ctx = context.WithValue(context.Background(), csrfTokenKey, "my-token") if CSRFTokenFromContext(ctx) != "my-token" { t.Error("CSRFTokenFromContext failed") } if CSRFTokenFromContext(context.Background()) != "" { t.Error("CSRFTokenFromContext should return empty for missing key") } // RealIPFromContext. ctx = context.WithValue(context.Background(), realIPKey, "1.2.3.4") if RealIPFromContext(ctx) != "1.2.3.4" { t.Error("RealIPFromContext failed") } if RealIPFromContext(context.Background()) != "" { t.Error("RealIPFromContext should return empty for missing key") } } func TestMustResetPasswordGuardRedirects(t *testing.T) { inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) handler := MustResetPasswordGuard("")(inner) // User with must_reset_password on a regular path → redirect. user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true} ctx := context.WithValue(context.Background(), userKey, user) req := httptest.NewRequest("GET", "/faves", nil).WithContext(ctx) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusSeeOther { t.Errorf("must-reset on /faves: got %d, want 303", rr.Code) } if loc := rr.Header().Get("Location"); loc != "/reset-password" { t.Errorf("redirect location = %q, want /reset-password", loc) } } func TestMustResetPasswordGuardAllowsPaths(t *testing.T) { inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) handler := MustResetPasswordGuard("")(inner) user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true} ctx := context.WithValue(context.Background(), userKey, user) // These paths should pass through even with must_reset. allowedPaths := []string{"/reset-password", "/logout", "/health", "/static/css/style.css"} for _, path := range allowedPaths { req := httptest.NewRequest("GET", path, nil).WithContext(ctx) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("must-reset on %s: got %d, want 200", path, rr.Code) } } }