favoritter/internal/middleware/auth_test.go
Ole-Morten Duesund a8f3aa6f7e test: add comprehensive test suite (44 → 169 tests) and v1.1 plan
Add 125 new test functions across 10 new test files, covering:
- CSRF middleware (8 tests): double-submit cookie validation
- Auth middleware (12 tests): SessionLoader, RequireAdmin, context helpers
- API handlers (28 tests): auth, faves CRUD, tags, users, export/import
- Web handlers (41 tests): signup, login, password reset, fave CRUD,
  admin panel, feeds, import/export, profiles, settings
- Config (8 tests): env parsing, defaults, trusted proxies, normalization
- Database (6 tests): migrations, PRAGMAs, idempotency, seeding
- Image processing (10 tests): JPEG/PNG, resize, EXIF strip, path traversal
- Render (6 tests): page/error/partial rendering, template functions
- Settings store (3 tests): CRUD operations
- Regression tests for display name fallback and CSP-safe autocomplete

Also adds CSRF middleware to testServer chain for end-to-end CSRF
verification, TESTPLAN.md documenting coverage, and PLANS-v1.1.md
with implementation plans for notes+OG, PWA, editing UX, and admin.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 00:18:01 +02:00

299 lines
8.6 KiB
Go

// SPDX-License-Identifier: AGPL-3.0-or-later
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"kode.naiv.no/olemd/favoritter/internal/database"
"kode.naiv.no/olemd/favoritter/internal/model"
"kode.naiv.no/olemd/favoritter/internal/store"
)
// testStores opens a fresh in-memory database, applies all migrations and
// returns session and user stores backed by it. The database is closed when
// the test finishes.
//
// The Argon2 cost parameters are lowered so password hashing in
// users.Create stays fast. The previous package-level values are restored
// on cleanup so the override does not leak into other tests in this test
// binary.
func testStores(t *testing.T) (*store.SessionStore, *store.UserStore) {
	t.Helper()

	db, err := database.Open(":memory:")
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Register the close before migrating so a failed migration does not
	// leak the handle.
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Errorf("close db: %v", err)
		}
	})

	if err := database.Migrate(db); err != nil {
		t.Fatalf("migrate: %v", err)
	}

	prevMemory, prevTime := store.Argon2Memory, store.Argon2Time
	store.Argon2Memory = 1024
	store.Argon2Time = 1
	t.Cleanup(func() {
		store.Argon2Memory = prevMemory
		store.Argon2Time = prevTime
	})

	return store.NewSessionStore(db), store.NewUserStore(db)
}
// TestSessionLoaderValidToken checks that a request carrying the cookie of a
// freshly created session reaches the next handler with the session's owner
// available through UserFromContext.
func TestSessionLoaderValidToken(t *testing.T) {
	sessions, users := testStores(t)

	owner, err := users.Create("testuser", "pass123", "user")
	if err != nil {
		t.Fatalf("create user: %v", err)
	}
	sessionToken, err := sessions.Create(owner.ID)
	if err != nil {
		t.Fatalf("create session: %v", err)
	}

	var seen *model.User
	next := func(w http.ResponseWriter, r *http.Request) {
		seen = UserFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}
	wrapped := SessionLoader(sessions, users)(http.HandlerFunc(next))

	request := httptest.NewRequest("GET", "/faves", nil)
	request.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionToken})
	wrapped.ServeHTTP(httptest.NewRecorder(), request)

	switch {
	case seen == nil:
		t.Fatal("expected user in context, got nil")
	case seen.ID != owner.ID:
		t.Errorf("context user ID = %d, want %d", seen.ID, owner.ID)
	}
}
// TestSessionLoaderInvalidToken checks that an unknown session token is
// treated as anonymous. The request still succeeds, no user is placed in the
// context, and the response expires the session cookie.
func TestSessionLoaderInvalidToken(t *testing.T) {
	sessions, users := testStores(t)

	var seen *model.User
	handler := SessionLoader(sessions, users)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen = UserFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest("GET", "/faves", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "invalid-token"})
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("invalid token: got %d, want 200", rec.Code)
	}
	if seen != nil {
		t.Error("expected nil user for invalid token")
	}

	// net/http parses an expiring Set-Cookie (Max-Age<=0) as MaxAge == -1.
	cleared := false
	for _, c := range rec.Result().Cookies() {
		if c.Name == SessionCookieName && c.MaxAge == -1 {
			cleared = true
			break
		}
	}
	if !cleared {
		t.Error("expected session cookie to be cleared for invalid token")
	}
}
// TestSessionLoaderNoCookie checks that a request without any session cookie
// passes through with status 200 and an empty user in the context.
func TestSessionLoaderNoCookie(t *testing.T) {
	sessions, users := testStores(t)

	var seen *model.User
	capture := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen = UserFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	SessionLoader(sessions, users)(capture).ServeHTTP(rec, httptest.NewRequest("GET", "/faves", nil))

	if status := rec.Code; status != http.StatusOK {
		t.Errorf("no cookie: got %d, want 200", status)
	}
	if seen != nil {
		t.Error("expected nil user when no cookie")
	}
}
// TestSessionLoaderSkipsStaticPaths checks that requests for static assets
// and uploads reach the next handler with status 200 even when they carry a
// bogus session cookie.
//
// On a regular path a bogus token causes the session cookie to be cleared
// (see TestSessionLoaderInvalidToken). Asserting that no clearing cookie is
// sent here is what distinguishes "skipped" from "processed as invalid".
// Without that check the test would also pass if these paths were not
// skipped at all.
func TestSessionLoaderSkipsStaticPaths(t *testing.T) {
	sessions, users := testStores(t)

	var handlerCalled bool
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	})
	handler := SessionLoader(sessions, users)(inner)

	for _, path := range []string{"/static/css/style.css", "/uploads/image.jpg"} {
		handlerCalled = false

		req := httptest.NewRequest("GET", path, nil)
		req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "some-token"})
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)

		if !handlerCalled {
			t.Errorf("handler not called for %s", path)
		}
		if rr.Code != http.StatusOK {
			t.Errorf("%s: got %d, want 200", path, rr.Code)
		}
		// net/http reports an expiring Set-Cookie (Max-Age<=0) as MaxAge == -1.
		for _, c := range rr.Result().Cookies() {
			if c.Name == SessionCookieName && c.MaxAge == -1 {
				t.Errorf("%s: session cookie was cleared; static path should skip session loading", path)
			}
		}
	}
}
// TestSessionLoaderDisabledUser checks that a valid session belonging to a
// disabled account does not put the user into the request context. It also
// checks that the response clears the session cookie.
func TestSessionLoaderDisabledUser(t *testing.T) {
	sessions, users := testStores(t)

	// Setup errors are fatal. Otherwise a failed Create would leave user nil
	// and panic on user.ID, and a failed SetDisabled would silently test an
	// enabled account.
	user, err := users.Create("testuser", "pass123", "user")
	if err != nil {
		t.Fatalf("create user: %v", err)
	}
	token, err := sessions.Create(user.ID)
	if err != nil {
		t.Fatalf("create session: %v", err)
	}
	if err := users.SetDisabled(user.ID, true); err != nil {
		t.Fatalf("disable user: %v", err)
	}

	var ctxUser *model.User
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctxUser = UserFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})
	handler := SessionLoader(sessions, users)(inner)

	req := httptest.NewRequest("GET", "/faves", nil)
	req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token})
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if ctxUser != nil {
		t.Error("disabled user should not be in context")
	}

	// The session cookie should be cleared. NOTE(review): server-side deletion
	// of the session row is not asserted here.
	for _, c := range rr.Result().Cookies() {
		if c.Name == SessionCookieName && c.MaxAge == -1 {
			return
		}
	}
	t.Error("expected session cookie to be cleared for disabled user")
}
// TestRequireAdminRejectsNonAdmin checks that a logged-in user whose role is
// "user" gets 403 Forbidden from RequireAdmin.
func TestRequireAdminRejectsNonAdmin(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := RequireAdmin(okHandler)

	regular := &model.User{ID: 1, Username: "regular", Role: "user"}
	req := httptest.NewRequest("GET", "/admin", nil)
	req = req.WithContext(context.WithValue(context.Background(), userKey, regular))

	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("non-admin: got %d, want 403", got)
	}
}
// TestRequireAdminAllowsAdmin checks that a user with the "admin" role passes
// through RequireAdmin to the wrapped handler.
func TestRequireAdminAllowsAdmin(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := RequireAdmin(okHandler)

	admin := &model.User{ID: 1, Username: "admin", Role: "admin"}
	req := httptest.NewRequest("GET", "/admin", nil)
	req = req.WithContext(context.WithValue(context.Background(), userKey, admin))

	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("admin: got %d, want 200", got)
	}
}
// TestRequireAdminNoUser checks that an anonymous request (no user in the
// context) gets 403 Forbidden from RequireAdmin.
func TestRequireAdminNoUser(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	RequireAdmin(okHandler).ServeHTTP(rec, httptest.NewRequest("GET", "/admin", nil))

	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("no user: got %d, want 403", got)
	}
}
func TestContextHelpers(t *testing.T) {
// UserFromContext with user.
user := &model.User{ID: 42, Username: "tester"}
ctx := context.WithValue(context.Background(), userKey, user)
got := UserFromContext(ctx)
if got == nil || got.ID != 42 {
t.Errorf("UserFromContext with user: got %v", got)
}
// UserFromContext without user.
if UserFromContext(context.Background()) != nil {
t.Error("UserFromContext should return nil for empty context")
}
// CSRFTokenFromContext.
ctx = context.WithValue(context.Background(), csrfTokenKey, "my-token")
if CSRFTokenFromContext(ctx) != "my-token" {
t.Error("CSRFTokenFromContext failed")
}
if CSRFTokenFromContext(context.Background()) != "" {
t.Error("CSRFTokenFromContext should return empty for missing key")
}
// RealIPFromContext.
ctx = context.WithValue(context.Background(), realIPKey, "1.2.3.4")
if RealIPFromContext(ctx) != "1.2.3.4" {
t.Error("RealIPFromContext failed")
}
if RealIPFromContext(context.Background()) != "" {
t.Error("RealIPFromContext should return empty for missing key")
}
}
// TestMustResetPasswordGuardRedirects checks that a user flagged with
// MustResetPassword who requests an ordinary page is sent to /reset-password
// with a 303 See Other.
func TestMustResetPasswordGuardRedirects(t *testing.T) {
	guard := MustResetPasswordGuard("")
	handler := guard(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	flagged := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
	req := httptest.NewRequest("GET", "/faves", nil)
	req = req.WithContext(context.WithValue(context.Background(), userKey, flagged))

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusSeeOther {
		t.Errorf("must-reset on /faves: got %d, want 303", got)
	}
	location := rec.Header().Get("Location")
	if location != "/reset-password" {
		t.Errorf("redirect location = %q, want /reset-password", location)
	}
}
// TestMustResetPasswordGuardAllowsPaths checks that a user flagged with
// MustResetPassword can still reach the reset page, logout, the health check
// and static assets without being redirected.
func TestMustResetPasswordGuardAllowsPaths(t *testing.T) {
	handler := MustResetPasswordGuard("")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	flagged := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
	userCtx := context.WithValue(context.Background(), userKey, flagged)

	exempt := [...]string{"/reset-password", "/logout", "/health", "/static/css/style.css"}
	for i := range exempt {
		target := exempt[i]
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, httptest.NewRequest("GET", target, nil).WithContext(userCtx))
		if got := rec.Code; got != http.StatusOK {
			t.Errorf("must-reset on %s: got %d, want 200", target, got)
		}
	}
}