Add 125 new test functions across 10 new test files, covering: - CSRF middleware (8 tests): double-submit cookie validation - Auth middleware (12 tests): SessionLoader, RequireAdmin, context helpers - API handlers (28 tests): auth, faves CRUD, tags, users, export/import - Web handlers (41 tests): signup, login, password reset, fave CRUD, admin panel, feeds, import/export, profiles, settings - Config (8 tests): env parsing, defaults, trusted proxies, normalization - Database (6 tests): migrations, PRAGMAs, idempotency, seeding - Image processing (10 tests): JPEG/PNG, resize, EXIF strip, path traversal - Render (6 tests): page/error/partial rendering, template functions - Settings store (3 tests): CRUD operations - Regression tests for display name fallback and CSP-safe autocomplete Also adds CSRF middleware to testServer chain for end-to-end CSRF verification, TESTPLAN.md documenting coverage, and PLANS-v1.1.md with implementation plans for notes+OG, PWA, editing UX, and admin. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
299 lines
8.6 KiB
Go
299 lines
8.6 KiB
Go
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"kode.naiv.no/olemd/favoritter/internal/database"
|
|
"kode.naiv.no/olemd/favoritter/internal/model"
|
|
"kode.naiv.no/olemd/favoritter/internal/store"
|
|
)
|
|
|
|
// testStores creates in-memory stores for auth middleware tests.
|
|
func testStores(t *testing.T) (*store.SessionStore, *store.UserStore) {
|
|
t.Helper()
|
|
db, err := database.Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
if err := database.Migrate(db); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
t.Cleanup(func() { db.Close() })
|
|
|
|
store.Argon2Memory = 1024
|
|
store.Argon2Time = 1
|
|
|
|
return store.NewSessionStore(db), store.NewUserStore(db)
|
|
}
|
|
|
|
func TestSessionLoaderValidToken(t *testing.T) {
|
|
sessions, users := testStores(t)
|
|
|
|
user, err := users.Create("testuser", "pass123", "user")
|
|
if err != nil {
|
|
t.Fatalf("create user: %v", err)
|
|
}
|
|
token, err := sessions.Create(user.ID)
|
|
if err != nil {
|
|
t.Fatalf("create session: %v", err)
|
|
}
|
|
|
|
var ctxUser *model.User
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctxUser = UserFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := SessionLoader(sessions, users)(inner)
|
|
req := httptest.NewRequest("GET", "/faves", nil)
|
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token})
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if ctxUser == nil {
|
|
t.Fatal("expected user in context, got nil")
|
|
}
|
|
if ctxUser.ID != user.ID {
|
|
t.Errorf("context user ID = %d, want %d", ctxUser.ID, user.ID)
|
|
}
|
|
}
|
|
|
|
func TestSessionLoaderInvalidToken(t *testing.T) {
|
|
sessions, users := testStores(t)
|
|
|
|
var ctxUser *model.User
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctxUser = UserFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := SessionLoader(sessions, users)(inner)
|
|
req := httptest.NewRequest("GET", "/faves", nil)
|
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "invalid-token"})
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("invalid token: got %d, want 200", rr.Code)
|
|
}
|
|
if ctxUser != nil {
|
|
t.Error("expected nil user for invalid token")
|
|
}
|
|
|
|
// Should clear the invalid session cookie.
|
|
for _, c := range rr.Result().Cookies() {
|
|
if c.Name == SessionCookieName && c.MaxAge == -1 {
|
|
return // Cookie cleared, good.
|
|
}
|
|
}
|
|
t.Error("expected session cookie to be cleared for invalid token")
|
|
}
|
|
|
|
func TestSessionLoaderNoCookie(t *testing.T) {
|
|
sessions, users := testStores(t)
|
|
|
|
var ctxUser *model.User
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctxUser = UserFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := SessionLoader(sessions, users)(inner)
|
|
req := httptest.NewRequest("GET", "/faves", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("no cookie: got %d, want 200", rr.Code)
|
|
}
|
|
if ctxUser != nil {
|
|
t.Error("expected nil user when no cookie")
|
|
}
|
|
}
|
|
|
|
func TestSessionLoaderSkipsStaticPaths(t *testing.T) {
|
|
sessions, users := testStores(t)
|
|
|
|
var handlerCalled bool
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := SessionLoader(sessions, users)(inner)
|
|
|
|
for _, path := range []string{"/static/css/style.css", "/uploads/image.jpg"} {
|
|
handlerCalled = false
|
|
req := httptest.NewRequest("GET", path, nil)
|
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "some-token"})
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if !handlerCalled {
|
|
t.Errorf("handler not called for %s", path)
|
|
}
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("%s: got %d, want 200", path, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSessionLoaderDisabledUser(t *testing.T) {
|
|
sessions, users := testStores(t)
|
|
|
|
user, _ := users.Create("testuser", "pass123", "user")
|
|
token, _ := sessions.Create(user.ID)
|
|
users.SetDisabled(user.ID, true)
|
|
|
|
var ctxUser *model.User
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctxUser = UserFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := SessionLoader(sessions, users)(inner)
|
|
req := httptest.NewRequest("GET", "/faves", nil)
|
|
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token})
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if ctxUser != nil {
|
|
t.Error("disabled user should not be in context")
|
|
}
|
|
// Session should be deleted and cookie cleared.
|
|
for _, c := range rr.Result().Cookies() {
|
|
if c.Name == SessionCookieName && c.MaxAge == -1 {
|
|
return
|
|
}
|
|
}
|
|
t.Error("expected session cookie to be cleared for disabled user")
|
|
}
|
|
|
|
func TestRequireAdminRejectsNonAdmin(t *testing.T) {
|
|
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Regular user in context.
|
|
user := &model.User{ID: 1, Username: "regular", Role: "user"}
|
|
ctx := context.WithValue(context.Background(), userKey, user)
|
|
req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("non-admin: got %d, want 403", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAdminAllowsAdmin(t *testing.T) {
|
|
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
user := &model.User{ID: 1, Username: "admin", Role: "admin"}
|
|
ctx := context.WithValue(context.Background(), userKey, user)
|
|
req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("admin: got %d, want 200", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAdminNoUser(t *testing.T) {
|
|
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/admin", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("no user: got %d, want 403", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestContextHelpers(t *testing.T) {
|
|
// UserFromContext with user.
|
|
user := &model.User{ID: 42, Username: "tester"}
|
|
ctx := context.WithValue(context.Background(), userKey, user)
|
|
got := UserFromContext(ctx)
|
|
if got == nil || got.ID != 42 {
|
|
t.Errorf("UserFromContext with user: got %v", got)
|
|
}
|
|
|
|
// UserFromContext without user.
|
|
if UserFromContext(context.Background()) != nil {
|
|
t.Error("UserFromContext should return nil for empty context")
|
|
}
|
|
|
|
// CSRFTokenFromContext.
|
|
ctx = context.WithValue(context.Background(), csrfTokenKey, "my-token")
|
|
if CSRFTokenFromContext(ctx) != "my-token" {
|
|
t.Error("CSRFTokenFromContext failed")
|
|
}
|
|
if CSRFTokenFromContext(context.Background()) != "" {
|
|
t.Error("CSRFTokenFromContext should return empty for missing key")
|
|
}
|
|
|
|
// RealIPFromContext.
|
|
ctx = context.WithValue(context.Background(), realIPKey, "1.2.3.4")
|
|
if RealIPFromContext(ctx) != "1.2.3.4" {
|
|
t.Error("RealIPFromContext failed")
|
|
}
|
|
if RealIPFromContext(context.Background()) != "" {
|
|
t.Error("RealIPFromContext should return empty for missing key")
|
|
}
|
|
}
|
|
|
|
func TestMustResetPasswordGuardRedirects(t *testing.T) {
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
handler := MustResetPasswordGuard("")(inner)
|
|
|
|
// User with must_reset_password on a regular path → redirect.
|
|
user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
|
|
ctx := context.WithValue(context.Background(), userKey, user)
|
|
|
|
req := httptest.NewRequest("GET", "/faves", nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusSeeOther {
|
|
t.Errorf("must-reset on /faves: got %d, want 303", rr.Code)
|
|
}
|
|
if loc := rr.Header().Get("Location"); loc != "/reset-password" {
|
|
t.Errorf("redirect location = %q, want /reset-password", loc)
|
|
}
|
|
}
|
|
|
|
func TestMustResetPasswordGuardAllowsPaths(t *testing.T) {
|
|
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
handler := MustResetPasswordGuard("")(inner)
|
|
|
|
user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
|
|
ctx := context.WithValue(context.Background(), userKey, user)
|
|
|
|
// These paths should pass through even with must_reset.
|
|
allowedPaths := []string{"/reset-password", "/logout", "/health", "/static/css/style.css"}
|
|
for _, path := range allowedPaths {
|
|
req := httptest.NewRequest("GET", path, nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("must-reset on %s: got %d, want 200", path, rr.Code)
|
|
}
|
|
}
|
|
}
|