test: add comprehensive test suite (44 → 169 tests) and v1.1 plan

Add 125 new test functions across 10 new test files, covering:
- CSRF middleware (8 tests): double-submit cookie validation
- Auth middleware (12 tests): SessionLoader, RequireAdmin, context helpers
- API handlers (28 tests): auth, faves CRUD, tags, users, export/import
- Web handlers (41 tests): signup, login, password reset, fave CRUD,
  admin panel, feeds, import/export, profiles, settings
- Config (8 tests): env parsing, defaults, trusted proxies, normalization
- Database (6 tests): migrations, PRAGMAs, idempotency, seeding
- Image processing (10 tests): JPEG/PNG, resize, EXIF strip, path traversal
- Render (6 tests): page/error/partial rendering, template functions
- Settings store (3 tests): CRUD operations
- Regression tests for display name fallback and CSP-safe autocomplete

Also adds CSRF middleware to testServer chain for end-to-end CSRF
verification, TESTPLAN.md documenting coverage, and PLANS-v1.1.md
with implementation plans for notes+OG, PWA, editing UX, and admin.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Ole-Morten Duesund 2026-04-04 00:18:01 +02:00
commit a8f3aa6f7e
12 changed files with 3820 additions and 2 deletions

View file

@ -0,0 +1,204 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package config
import (
"testing"
"time"
)
func TestLoadDefaults(t *testing.T) {
// Clear all FAVORITTER_ env vars to test defaults.
for _, key := range []string{
"FAVORITTER_DB_PATH", "FAVORITTER_LISTEN", "FAVORITTER_BASE_PATH",
"FAVORITTER_EXTERNAL_URL", "FAVORITTER_UPLOAD_DIR", "FAVORITTER_MAX_UPLOAD_SIZE",
"FAVORITTER_SESSION_LIFETIME", "FAVORITTER_ARGON2_MEMORY", "FAVORITTER_ARGON2_TIME",
"FAVORITTER_ARGON2_PARALLELISM", "FAVORITTER_RATE_LIMIT", "FAVORITTER_ADMIN_USERNAME",
"FAVORITTER_ADMIN_PASSWORD", "FAVORITTER_SITE_NAME", "FAVORITTER_DEV_MODE",
"FAVORITTER_TRUSTED_PROXIES",
} {
t.Setenv(key, "")
}
cfg := Load()
if cfg.DBPath != "./data/favoritter.db" {
t.Errorf("DBPath = %q, want default", cfg.DBPath)
}
if cfg.Listen != ":8080" {
t.Errorf("Listen = %q, want :8080", cfg.Listen)
}
if cfg.BasePath != "" {
t.Errorf("BasePath = %q, want empty (root)", cfg.BasePath)
}
if cfg.MaxUploadSize != 10<<20 {
t.Errorf("MaxUploadSize = %d, want %d", cfg.MaxUploadSize, 10<<20)
}
if cfg.SessionLifetime != 720*time.Hour {
t.Errorf("SessionLifetime = %v, want 720h", cfg.SessionLifetime)
}
if cfg.SiteName != "Favoritter" {
t.Errorf("SiteName = %q, want Favoritter", cfg.SiteName)
}
if cfg.DevMode {
t.Error("DevMode should be false by default")
}
if cfg.RateLimit != 60 {
t.Errorf("RateLimit = %d, want 60", cfg.RateLimit)
}
}
func TestLoadFromEnv(t *testing.T) {
t.Setenv("FAVORITTER_DB_PATH", "/custom/db.sqlite")
t.Setenv("FAVORITTER_LISTEN", ":9090")
t.Setenv("FAVORITTER_BASE_PATH", "/faves")
t.Setenv("FAVORITTER_EXTERNAL_URL", "https://faves.example.com/")
t.Setenv("FAVORITTER_UPLOAD_DIR", "/custom/uploads")
t.Setenv("FAVORITTER_MAX_UPLOAD_SIZE", "20971520")
t.Setenv("FAVORITTER_SESSION_LIFETIME", "48h")
t.Setenv("FAVORITTER_ARGON2_MEMORY", "131072")
t.Setenv("FAVORITTER_ARGON2_TIME", "5")
t.Setenv("FAVORITTER_ARGON2_PARALLELISM", "4")
t.Setenv("FAVORITTER_RATE_LIMIT", "100")
t.Setenv("FAVORITTER_ADMIN_USERNAME", "admin")
t.Setenv("FAVORITTER_ADMIN_PASSWORD", "secret")
t.Setenv("FAVORITTER_SITE_NAME", "Mine Favoritter")
t.Setenv("FAVORITTER_DEV_MODE", "true")
t.Setenv("FAVORITTER_TRUSTED_PROXIES", "10.0.0.0/8,192.168.1.0/24")
cfg := Load()
if cfg.DBPath != "/custom/db.sqlite" {
t.Errorf("DBPath = %q", cfg.DBPath)
}
if cfg.Listen != ":9090" {
t.Errorf("Listen = %q", cfg.Listen)
}
if cfg.BasePath != "/faves" {
t.Errorf("BasePath = %q, want /faves", cfg.BasePath)
}
// External URL should have trailing slash stripped.
if cfg.ExternalURL != "https://faves.example.com" {
t.Errorf("ExternalURL = %q, want trailing slash stripped", cfg.ExternalURL)
}
if cfg.MaxUploadSize != 20971520 {
t.Errorf("MaxUploadSize = %d", cfg.MaxUploadSize)
}
if cfg.SessionLifetime != 48*time.Hour {
t.Errorf("SessionLifetime = %v", cfg.SessionLifetime)
}
if cfg.Argon2Memory != 131072 {
t.Errorf("Argon2Memory = %d", cfg.Argon2Memory)
}
if cfg.Argon2Time != 5 {
t.Errorf("Argon2Time = %d", cfg.Argon2Time)
}
if cfg.Argon2Parallelism != 4 {
t.Errorf("Argon2Parallelism = %d", cfg.Argon2Parallelism)
}
if cfg.RateLimit != 100 {
t.Errorf("RateLimit = %d", cfg.RateLimit)
}
if cfg.AdminUsername != "admin" {
t.Errorf("AdminUsername = %q", cfg.AdminUsername)
}
if cfg.AdminPassword != "secret" {
t.Errorf("AdminPassword = %q", cfg.AdminPassword)
}
if cfg.SiteName != "Mine Favoritter" {
t.Errorf("SiteName = %q", cfg.SiteName)
}
if !cfg.DevMode {
t.Error("DevMode should be true")
}
if len(cfg.TrustedProxies) != 2 {
t.Errorf("TrustedProxies: got %d, want 2", len(cfg.TrustedProxies))
}
}
func TestTrustedProxiesParsing(t *testing.T) {
t.Setenv("FAVORITTER_TRUSTED_PROXIES", "10.0.0.0/8, 192.168.0.0/16, 127.0.0.1")
cfg := Load()
if len(cfg.TrustedProxies) != 3 {
t.Fatalf("TrustedProxies: got %d, want 3", len(cfg.TrustedProxies))
}
// 127.0.0.1 without CIDR should become 127.0.0.1/32.
last := cfg.TrustedProxies[2]
if ones, _ := last.Mask.Size(); ones != 32 {
t.Errorf("bare IP mask = /%d, want /32", ones)
}
}
func TestTrustedProxiesInvalid(t *testing.T) {
t.Setenv("FAVORITTER_TRUSTED_PROXIES", "not-an-ip, 10.0.0.0/8")
cfg := Load()
// Invalid entries are skipped; valid ones remain.
if len(cfg.TrustedProxies) != 1 {
t.Errorf("TrustedProxies: got %d, want 1 (invalid skipped)", len(cfg.TrustedProxies))
}
}
func TestBasePathNormalization(t *testing.T) {
tests := []struct {
input string
want string
}{
{"/", ""},
{"", ""},
{"/faves", "/faves"},
{"/faves/", "/faves"},
{"faves", "/faves"},
{"/sub/path/", "/sub/path"},
}
for _, tt := range tests {
got := normalizePath(tt.input)
if got != tt.want {
t.Errorf("normalizePath(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDevModeFlag(t *testing.T) {
t.Setenv("FAVORITTER_DEV_MODE", "true")
cfg := Load()
if !cfg.DevMode {
t.Error("DevMode should be true when env is 'true'")
}
t.Setenv("FAVORITTER_DEV_MODE", "false")
cfg = Load()
if cfg.DevMode {
t.Error("DevMode should be false when env is 'false'")
}
}
func TestExternalHostname(t *testing.T) {
cfg := &Config{ExternalURL: "https://faves.example.com/base"}
if got := cfg.ExternalHostname(); got != "faves.example.com" {
t.Errorf("ExternalHostname = %q, want faves.example.com", got)
}
cfg = &Config{}
if got := cfg.ExternalHostname(); got != "" {
t.Errorf("empty ExternalURL: ExternalHostname = %q, want empty", got)
}
}
func TestBaseURL(t *testing.T) {
// With external URL configured.
cfg := &Config{ExternalURL: "https://faves.example.com"}
if got := cfg.BaseURL("localhost:8080"); got != "https://faves.example.com" {
t.Errorf("BaseURL with external = %q", got)
}
// Without external URL — falls back to request host.
cfg = &Config{BasePath: "/faves"}
if got := cfg.BaseURL("myhost.local:8080"); got != "https://myhost.local:8080/faves" {
t.Errorf("BaseURL fallback = %q", got)
}
}

View file

@ -0,0 +1,140 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package database
import (
"testing"
_ "modernc.org/sqlite"
)
func TestOpenInMemory(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("Open(:memory:): %v", err)
}
defer db.Close()
// Verify the connection is usable.
var result int
if err := db.QueryRow("SELECT 1").Scan(&result); err != nil {
t.Fatalf("query: %v", err)
}
if result != 1 {
t.Errorf("SELECT 1 = %d", result)
}
}
func TestMigrateCreatesAllTables(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("open: %v", err)
}
defer db.Close()
if err := Migrate(db); err != nil {
t.Fatalf("migrate: %v", err)
}
// Check that core tables exist.
tables := []string{"users", "faves", "tags", "fave_tags", "sessions", "site_settings", "schema_migrations", "signup_requests"}
for _, table := range tables {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&count)
if err != nil {
t.Errorf("check table %s: %v", table, err)
}
if count != 1 {
t.Errorf("table %s does not exist", table)
}
}
}
func TestMigrateIdempotent(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("open: %v", err)
}
defer db.Close()
// First migration.
if err := Migrate(db); err != nil {
t.Fatalf("first migrate: %v", err)
}
// Second migration should be a no-op.
if err := Migrate(db); err != nil {
t.Fatalf("second migrate: %v", err)
}
// Verify schema_migrations has entries (no duplicates).
var count int
db.QueryRow("SELECT COUNT(*) FROM schema_migrations").Scan(&count)
if count < 1 {
t.Error("expected at least one migration record")
}
}
func TestPRAGMAs(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("open: %v", err)
}
defer db.Close()
// WAL mode.
var journalMode string
db.QueryRow("PRAGMA journal_mode").Scan(&journalMode)
// In-memory databases use "memory" journal mode, not "wal".
// WAL is only meaningful for file-based databases.
// We just verify the pragma was accepted without error.
// Foreign keys should be ON.
var fk int
db.QueryRow("PRAGMA foreign_keys").Scan(&fk)
if fk != 1 {
t.Errorf("foreign_keys = %d, want 1", fk)
}
// Busy timeout.
var timeout int
db.QueryRow("PRAGMA busy_timeout").Scan(&timeout)
if timeout != 5000 {
t.Errorf("busy_timeout = %d, want 5000", timeout)
}
}
func TestSingleConnection(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("open: %v", err)
}
defer db.Close()
stats := db.Stats()
if stats.MaxOpenConnections != 1 {
t.Errorf("MaxOpenConnections = %d, want 1", stats.MaxOpenConnections)
}
}
func TestSiteSettingsSeeded(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatalf("open: %v", err)
}
defer db.Close()
if err := Migrate(db); err != nil {
t.Fatalf("migrate: %v", err)
}
// Migrations should seed a default site_settings row.
var siteName string
err = db.QueryRow("SELECT site_name FROM site_settings WHERE id = 1").Scan(&siteName)
if err != nil {
t.Fatalf("query site_settings: %v", err)
}
if siteName == "" {
t.Error("expected non-empty default site_name")
}
}

View file

@ -0,0 +1,663 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"kode.naiv.no/olemd/favoritter/internal/config"
"kode.naiv.no/olemd/favoritter/internal/database"
"kode.naiv.no/olemd/favoritter/internal/middleware"
"kode.naiv.no/olemd/favoritter/internal/store"
)
// testAPIServer creates a wired API handler with in-memory DB.
func testAPIServer(t *testing.T) (*Handler, *http.ServeMux, *store.UserStore, *store.SessionStore) {
t.Helper()
db, err := database.Open(":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(db); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { db.Close() })
store.Argon2Memory = 1024
store.Argon2Time = 1
cfg := &config.Config{
MaxUploadSize: 10 << 20, // 10 MB
}
users := store.NewUserStore(db)
sessions := store.NewSessionStore(db)
faves := store.NewFaveStore(db)
tags := store.NewTagStore(db)
h := New(Deps{
Config: cfg,
Users: users,
Sessions: sessions,
Faves: faves,
Tags: tags,
})
mux := http.NewServeMux()
h.Routes(mux)
// Wrap with SessionLoader so authenticated API requests work.
chain := middleware.SessionLoader(sessions, users)(mux)
wrappedMux := http.NewServeMux()
wrappedMux.Handle("/", chain)
return h, wrappedMux, users, sessions
}
// apiLogin creates a user and returns a session cookie.
func apiLogin(t *testing.T, users *store.UserStore, sessions *store.SessionStore, username, password, role string) *http.Cookie {
t.Helper()
user, err := users.Create(username, password, role)
if err != nil {
t.Fatalf("create user %s: %v", username, err)
}
token, err := sessions.Create(user.ID)
if err != nil {
t.Fatalf("create session: %v", err)
}
return &http.Cookie{Name: "session", Value: token}
}
// jsonBody is a helper to parse JSON response bodies.
func jsonBody(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
t.Helper()
var result map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &result); err != nil {
t.Fatalf("parse response JSON: %v\nbody: %s", err, rr.Body.String())
}
return result
}
// --- Auth ---
func TestAPILoginSuccess(t *testing.T) {
_, mux, users, _ := testAPIServer(t)
users.Create("testuser", "password123", "user")
body := `{"username":"testuser","password":"password123"}`
req := httptest.NewRequest("POST", "/api/v1/auth/login", strings.NewReader(body))
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("login: got %d, want 200\nbody: %s", rr.Code, rr.Body.String())
}
result := jsonBody(t, rr)
if result["token"] == nil || result["token"] == "" {
t.Error("expected token in response")
}
user, ok := result["user"].(map[string]any)
if !ok {
t.Fatal("expected user object in response")
}
if user["username"] != "testuser" {
t.Errorf("username = %v, want testuser", user["username"])
}
}
func TestAPILoginWrongPassword(t *testing.T) {
_, mux, users, _ := testAPIServer(t)
users.Create("testuser", "password123", "user")
body := `{"username":"testuser","password":"wrong"}`
req := httptest.NewRequest("POST", "/api/v1/auth/login", strings.NewReader(body))
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("wrong password: got %d, want 401", rr.Code)
}
}
func TestAPILoginInvalidBody(t *testing.T) {
_, mux, _, _ := testAPIServer(t)
req := httptest.NewRequest("POST", "/api/v1/auth/login", strings.NewReader("not json"))
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("invalid body: got %d, want 400", rr.Code)
}
}
func TestAPILogout(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
req := httptest.NewRequest("POST", "/api/v1/auth/logout", nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("logout: got %d, want 200", rr.Code)
}
// Session should be invalid now.
_, err := sessions.Validate(cookie.Value)
if err == nil {
t.Error("session should be invalidated after logout")
}
}
// --- Faves CRUD ---
func TestAPICreateFave(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
body := `{"description":"My favorite thing","url":"https://example.com","privacy":"public","tags":["go","web"]}`
req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(body))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusCreated {
t.Fatalf("create fave: got %d, want 201\nbody: %s", rr.Code, rr.Body.String())
}
result := jsonBody(t, rr)
if result["description"] != "My favorite thing" {
t.Errorf("description = %v", result["description"])
}
if result["url"] != "https://example.com" {
t.Errorf("url = %v", result["url"])
}
tags, ok := result["tags"].([]any)
if !ok || len(tags) != 2 {
t.Errorf("expected 2 tags, got %v", result["tags"])
}
}
func TestAPICreateFaveMissingDescription(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
body := `{"url":"https://example.com"}`
req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(body))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("missing description: got %d, want 400", rr.Code)
}
}
func TestAPICreateFaveRequiresAuth(t *testing.T) {
_, mux, _, _ := testAPIServer(t)
body := `{"description":"test"}`
req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(body))
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
// Should redirect or return non-2xx.
if rr.Code == http.StatusCreated || rr.Code == http.StatusOK {
t.Errorf("unauthenticated create: got %d, should not be 2xx", rr.Code)
}
}
func TestAPIGetFave(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
// Create a public fave directly.
user, _ := users.GetByUsername("testuser")
fave, _ := h.deps.Faves.Create(user.ID, "Test fave", "https://example.com", "", "public")
h.deps.Tags.SetFaveTags(fave.ID, []string{"test"})
req := httptest.NewRequest("GET", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("get fave: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
if result["description"] != "Test fave" {
t.Errorf("description = %v", result["description"])
}
}
func TestAPIGetFaveNotFound(t *testing.T) {
_, mux, _, _ := testAPIServer(t)
req := httptest.NewRequest("GET", "/api/v1/faves/99999", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("nonexistent fave: got %d, want 404", rr.Code)
}
}
func TestAPIPrivateFaveHiddenFromOthers(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
// User A creates a private fave.
userA, _ := users.Create("usera", "pass123", "user")
fave, _ := h.deps.Faves.Create(userA.ID, "Secret", "", "", "private")
// User B tries to access it.
cookieB := apiLogin(t, users, sessions, "userb", "pass123", "user")
req := httptest.NewRequest("GET", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookieB)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("private fave for other user: got %d, want 404", rr.Code)
}
}
func TestAPIPrivateFaveVisibleToOwner(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
userA, _ := users.Create("usera", "pass123", "user")
fave, _ := h.deps.Faves.Create(userA.ID, "My secret", "", "", "private")
tokenA, _ := sessions.Create(userA.ID)
cookieA := &http.Cookie{Name: "session", Value: tokenA}
req := httptest.NewRequest("GET", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookieA)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("own private fave: got %d, want 200", rr.Code)
}
}
func TestAPIUpdateFave(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
fave, _ := h.deps.Faves.Create(user.ID, "Original", "https://old.com", "", "public")
token, _ := sessions.Create(user.ID)
cookie := &http.Cookie{Name: "session", Value: token}
body := `{"description":"Updated","url":"https://new.com","tags":["updated"]}`
req := httptest.NewRequest("PUT", "/api/v1/faves/"+faveIDStr(fave.ID), strings.NewReader(body))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("update fave: got %d, want 200\nbody: %s", rr.Code, rr.Body.String())
}
result := jsonBody(t, rr)
if result["description"] != "Updated" {
t.Errorf("description = %v, want Updated", result["description"])
}
if result["url"] != "https://new.com" {
t.Errorf("url = %v, want https://new.com", result["url"])
}
}
func TestAPIUpdateFaveNotOwner(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
userA, _ := users.Create("usera", "pass123", "user")
fave, _ := h.deps.Faves.Create(userA.ID, "A's fave", "", "", "public")
cookieB := apiLogin(t, users, sessions, "userb", "pass123", "user")
body := `{"description":"Hijacked"}`
req := httptest.NewRequest("PUT", "/api/v1/faves/"+faveIDStr(fave.ID), strings.NewReader(body))
req.AddCookie(cookieB)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("update by non-owner: got %d, want 403", rr.Code)
}
}
func TestAPIDeleteFave(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
fave, _ := h.deps.Faves.Create(user.ID, "Delete me", "", "", "public")
token, _ := sessions.Create(user.ID)
cookie := &http.Cookie{Name: "session", Value: token}
req := httptest.NewRequest("DELETE", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Errorf("delete fave: got %d, want 204", rr.Code)
}
// Verify it's gone.
req = httptest.NewRequest("GET", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookie)
rr = httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("deleted fave: got %d, want 404", rr.Code)
}
}
func TestAPIDeleteFaveNotOwner(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
userA, _ := users.Create("usera", "pass123", "user")
fave, _ := h.deps.Faves.Create(userA.ID, "A's fave", "", "", "public")
cookieB := apiLogin(t, users, sessions, "userb", "pass123", "user")
req := httptest.NewRequest("DELETE", "/api/v1/faves/"+faveIDStr(fave.ID), nil)
req.AddCookie(cookieB)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("delete by non-owner: got %d, want 403", rr.Code)
}
}
func TestAPIListFaves(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
h.deps.Faves.Create(user.ID, "Fave 1", "", "", "public")
h.deps.Faves.Create(user.ID, "Fave 2", "", "", "public")
h.deps.Faves.Create(user.ID, "Fave 3", "", "", "private")
token, _ := sessions.Create(user.ID)
cookie := &http.Cookie{Name: "session", Value: token}
req := httptest.NewRequest("GET", "/api/v1/faves", nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("list faves: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
total, _ := result["total"].(float64)
if total != 3 {
t.Errorf("total = %v, want 3 (all faves including private)", total)
}
}
func TestAPIListFavesPagination(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
for i := 0; i < 5; i++ {
h.deps.Faves.Create(user.ID, "Fave", "", "", "public")
}
token, _ := sessions.Create(user.ID)
cookie := &http.Cookie{Name: "session", Value: token}
req := httptest.NewRequest("GET", "/api/v1/faves?page=1&limit=2", nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
result := jsonBody(t, rr)
faves, ok := result["faves"].([]any)
if !ok {
t.Fatal("expected faves array")
}
if len(faves) != 2 {
t.Errorf("page size: got %d faves, want 2", len(faves))
}
total, _ := result["total"].(float64)
if total != 5 {
t.Errorf("total = %v, want 5", total)
}
}
// --- Tags ---
func TestAPISearchTags(t *testing.T) {
h, mux, users, _ := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
fave, _ := h.deps.Faves.Create(user.ID, "Test", "", "", "public")
h.deps.Tags.SetFaveTags(fave.ID, []string{"golang", "goroutines", "python"})
req := httptest.NewRequest("GET", "/api/v1/tags?q=go", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("search tags: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
tags, ok := result["tags"].([]any)
if !ok {
t.Fatal("expected tags array")
}
if len(tags) < 1 {
t.Error("expected at least one tag matching 'go'")
}
}
func TestAPISearchTagsEmpty(t *testing.T) {
_, mux, _, _ := testAPIServer(t)
req := httptest.NewRequest("GET", "/api/v1/tags?q=nonexistent", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("empty tag search: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
tags, _ := result["tags"].([]any)
if len(tags) != 0 {
t.Errorf("expected empty tags, got %v", tags)
}
}
// --- Users ---
func TestAPIGetUser(t *testing.T) {
_, mux, users, _ := testAPIServer(t)
users.Create("testuser", "pass123", "user")
req := httptest.NewRequest("GET", "/api/v1/users/testuser", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("get user: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
if result["username"] != "testuser" {
t.Errorf("username = %v", result["username"])
}
}
func TestAPIGetUserNotFound(t *testing.T) {
_, mux, _, _ := testAPIServer(t)
req := httptest.NewRequest("GET", "/api/v1/users/nobody", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("nonexistent user: got %d, want 404", rr.Code)
}
}
func TestAPIGetDisabledUser(t *testing.T) {
_, mux, users, _ := testAPIServer(t)
user, _ := users.Create("disabled", "pass123", "user")
users.SetDisabled(user.ID, true)
req := httptest.NewRequest("GET", "/api/v1/users/disabled", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("disabled user: got %d, want 404", rr.Code)
}
}
func TestAPIGetUserFaves(t *testing.T) {
h, mux, users, _ := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
h.deps.Faves.Create(user.ID, "Public fave", "", "", "public")
h.deps.Faves.Create(user.ID, "Private fave", "", "", "private")
req := httptest.NewRequest("GET", "/api/v1/users/testuser/faves", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("user faves: got %d, want 200", rr.Code)
}
result := jsonBody(t, rr)
total, _ := result["total"].(float64)
if total != 1 {
t.Errorf("total = %v, want 1 (only public faves)", total)
}
}
// --- Export/Import ---
func TestAPIExport(t *testing.T) {
h, mux, users, sessions := testAPIServer(t)
user, _ := users.Create("testuser", "pass123", "user")
h.deps.Faves.Create(user.ID, "Fave 1", "", "", "public")
h.deps.Faves.Create(user.ID, "Fave 2", "", "", "private")
token, _ := sessions.Create(user.ID)
cookie := &http.Cookie{Name: "session", Value: token}
req := httptest.NewRequest("GET", "/api/v1/export/json", nil)
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("export: got %d, want 200", rr.Code)
}
// Export returns a JSON array directly.
var faves []any
if err := json.Unmarshal(rr.Body.Bytes(), &faves); err != nil {
t.Fatalf("parse export JSON: %v", err)
}
if len(faves) != 2 {
t.Errorf("exported %d faves, want 2", len(faves))
}
}
func TestAPIImportValid(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
body := `[{"description":"Imported 1","privacy":"public"},{"description":"Imported 2","tags":["test"]}]`
req := httptest.NewRequest("POST", "/api/v1/import", strings.NewReader(body))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("import: got %d, want 200\nbody: %s", rr.Code, rr.Body.String())
}
result := jsonBody(t, rr)
imported, _ := result["imported"].(float64)
if imported != 2 {
t.Errorf("imported = %v, want 2", imported)
}
}
func TestAPIImportSkipsEmpty(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
body := `[{"description":"Valid"},{"description":"","url":"https://empty.com"}]`
req := httptest.NewRequest("POST", "/api/v1/import", strings.NewReader(body))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
result := jsonBody(t, rr)
imported, _ := result["imported"].(float64)
total, _ := result["total"].(float64)
if imported != 1 {
t.Errorf("imported = %v, want 1", imported)
}
if total != 2 {
t.Errorf("total = %v, want 2", total)
}
}
func TestAPIImportInvalidJSON(t *testing.T) {
_, mux, users, sessions := testAPIServer(t)
cookie := apiLogin(t, users, sessions, "testuser", "pass123", "user")
req := httptest.NewRequest("POST", "/api/v1/import", strings.NewReader("not json"))
req.AddCookie(cookie)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("invalid JSON import: got %d, want 400", rr.Code)
}
}
// --- JSON helpers ---
func TestQueryIntFallback(t *testing.T) {
tests := []struct {
query string
want int
}{
{"", 10},
{"page=abc", 10},
{"page=-1", 10},
{"page=0", 10},
{"page=5", 5},
}
for _, tt := range tests {
req := httptest.NewRequest("GET", "/test?"+tt.query, nil)
got := queryInt(req, "page", 10)
if got != tt.want {
t.Errorf("queryInt(%q) = %d, want %d", tt.query, got, tt.want)
}
}
}
// faveIDStr converts an int64 to a string for URL paths.
func faveIDStr(id int64) string {
return strconv.FormatInt(id, 10)
}

View file

@ -67,8 +67,8 @@ func testServer(t *testing.T) (*Handler, *http.ServeMux) {
mux := h.Routes()
// Wrap with SessionLoader so authenticated tests work.
chain := middleware.SessionLoader(sessions, users)(mux)
// Wrap with SessionLoader and CSRFProtection so authenticated tests work.
chain := middleware.CSRFProtection(cfg)(middleware.SessionLoader(sessions, users)(mux))
wrappedMux := http.NewServeMux()
wrappedMux.Handle("/", chain)

1153
internal/handler/web_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,234 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package image
import (
"bytes"
"image"
"image/jpeg"
"image/png"
"mime/multipart"
"net/textproto"
"os"
"path/filepath"
"strings"
"testing"
)
// testJPEG creates a test JPEG image in memory with given dimensions.
func testJPEG(t *testing.T, width, height int) (*bytes.Buffer, *multipart.FileHeader) {
t.Helper()
img := image.NewRGBA(image.Rect(0, 0, width, height))
var buf bytes.Buffer
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 75}); err != nil {
t.Fatalf("encode test jpeg: %v", err)
}
header := &multipart.FileHeader{
Filename: "test.jpg",
Size: int64(buf.Len()),
Header: textproto.MIMEHeader{"Content-Type": {"image/jpeg"}},
}
return &buf, header
}
// testPNG creates a test PNG image in memory with given dimensions.
func testPNG(t *testing.T, width, height int) (*bytes.Buffer, *multipart.FileHeader) {
t.Helper()
img := image.NewRGBA(image.Rect(0, 0, width, height))
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("encode test png: %v", err)
}
header := &multipart.FileHeader{
Filename: "test.png",
Size: int64(buf.Len()),
Header: textproto.MIMEHeader{"Content-Type": {"image/png"}},
}
return &buf, header
}
// bufferReadSeeker wraps a bytes.Reader to implement multipart.File.
type bufferReadSeeker struct {
*bytes.Reader
}
func (b *bufferReadSeeker) Close() error { return nil }
func TestProcessJPEG(t *testing.T) {
uploadDir := t.TempDir()
buf, header := testJPEG(t, 800, 600)
result, err := Process(&bufferReadSeeker{bytes.NewReader(buf.Bytes())}, header, uploadDir)
if err != nil {
t.Fatalf("Process JPEG: %v", err)
}
if result.Filename == "" {
t.Error("expected non-empty filename")
}
if !strings.HasSuffix(result.Filename, ".jpg") {
t.Errorf("filename %q should end with .jpg", result.Filename)
}
// Verify file was written.
if _, err := os.Stat(result.Path); err != nil {
t.Errorf("output file not found: %v", err)
}
}
func TestProcessPNG(t *testing.T) {
uploadDir := t.TempDir()
buf, header := testPNG(t, 640, 480)
result, err := Process(&bufferReadSeeker{bytes.NewReader(buf.Bytes())}, header, uploadDir)
if err != nil {
t.Fatalf("Process PNG: %v", err)
}
if !strings.HasSuffix(result.Filename, ".png") {
t.Errorf("filename %q should end with .png", result.Filename)
}
}
func TestProcessResizeWideImage(t *testing.T) {
uploadDir := t.TempDir()
buf, header := testJPEG(t, 3840, 2160) // 4K width
result, err := Process(&bufferReadSeeker{bytes.NewReader(buf.Bytes())}, header, uploadDir)
if err != nil {
t.Fatalf("Process wide image: %v", err)
}
// Read back and check dimensions.
f, err := os.Open(result.Path)
if err != nil {
t.Fatalf("open result: %v", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
t.Fatalf("decode result: %v", err)
}
bounds := img.Bounds()
if bounds.Dx() != MaxWidth {
t.Errorf("resized width = %d, want %d", bounds.Dx(), MaxWidth)
}
// Aspect ratio should be maintained.
expectedHeight := 2160 * MaxWidth / 3840
if bounds.Dy() != expectedHeight {
t.Errorf("resized height = %d, want %d", bounds.Dy(), expectedHeight)
}
}
func TestProcessSmallImageNotResized(t *testing.T) {
uploadDir := t.TempDir()
buf, header := testJPEG(t, 800, 600)
result, err := Process(&bufferReadSeeker{bytes.NewReader(buf.Bytes())}, header, uploadDir)
if err != nil {
t.Fatalf("Process small image: %v", err)
}
f, err := os.Open(result.Path)
if err != nil {
t.Fatalf("open result: %v", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
t.Fatalf("decode result: %v", err)
}
if img.Bounds().Dx() != 800 {
t.Errorf("small image width = %d, should not be resized", img.Bounds().Dx())
}
}
func TestProcessInvalidMIME(t *testing.T) {
uploadDir := t.TempDir()
header := &multipart.FileHeader{
Filename: "test.txt",
Header: textproto.MIMEHeader{"Content-Type": {"text/plain"}},
}
buf := bytes.NewReader([]byte("not an image"))
_, err := Process(&bufferReadSeeker{buf}, header, uploadDir)
if err == nil {
t.Error("expected error for text/plain MIME type")
}
if !strings.Contains(err.Error(), "unsupported image type") {
t.Errorf("error = %q, should mention unsupported type", err)
}
}
func TestProcessCorruptImage(t *testing.T) {
uploadDir := t.TempDir()
header := &multipart.FileHeader{
Filename: "corrupt.jpg",
Header: textproto.MIMEHeader{"Content-Type": {"image/jpeg"}},
}
buf := bytes.NewReader([]byte("this is not valid jpeg data"))
_, err := Process(&bufferReadSeeker{buf}, header, uploadDir)
if err == nil {
t.Error("expected error for corrupt image data")
}
}
func TestProcessUUIDFilename(t *testing.T) {
uploadDir := t.TempDir()
buf, header := testJPEG(t, 100, 100)
// Give a user-supplied filename.
header.Filename = "my-vacation-photo.jpg"
result, err := Process(&bufferReadSeeker{bytes.NewReader(buf.Bytes())}, header, uploadDir)
if err != nil {
t.Fatalf("Process: %v", err)
}
if strings.Contains(result.Filename, "vacation") {
t.Error("filename should be UUID-based, not user-supplied")
}
// UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.ext
if len(result.Filename) < 36 {
t.Errorf("filename %q too short for UUID", result.Filename)
}
}
func TestAllowedTypes(t *testing.T) {
expected := []string{"image/jpeg", "image/png", "image/gif", "image/webp"}
for _, mime := range expected {
if _, ok := AllowedTypes[mime]; !ok {
t.Errorf("AllowedTypes missing %s", mime)
}
}
}
func TestDeletePathTraversal(t *testing.T) {
uploadDir := t.TempDir()
// Create a file outside uploadDir.
outsideFile := filepath.Join(t.TempDir(), "sensitive.txt")
os.WriteFile(outsideFile, []byte("secret"), 0644)
// Attempt to delete it via path traversal.
err := Delete(uploadDir, "../../../"+filepath.Base(outsideFile))
if err == nil {
t.Error("expected error for path traversal")
}
// File should still exist.
if _, statErr := os.Stat(outsideFile); statErr != nil {
t.Error("path traversal should not have deleted the file")
}
}
func TestDeleteEmpty(t *testing.T) {
// Empty filename should be a no-op.
if err := Delete(t.TempDir(), ""); err != nil {
t.Errorf("Delete empty filename: %v", err)
}
}

View file

@ -0,0 +1,299 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"kode.naiv.no/olemd/favoritter/internal/database"
"kode.naiv.no/olemd/favoritter/internal/model"
"kode.naiv.no/olemd/favoritter/internal/store"
)
// testStores creates in-memory stores for auth middleware tests.
func testStores(t *testing.T) (*store.SessionStore, *store.UserStore) {
t.Helper()
db, err := database.Open(":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(db); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { db.Close() })
store.Argon2Memory = 1024
store.Argon2Time = 1
return store.NewSessionStore(db), store.NewUserStore(db)
}
func TestSessionLoaderValidToken(t *testing.T) {
sessions, users := testStores(t)
user, err := users.Create("testuser", "pass123", "user")
if err != nil {
t.Fatalf("create user: %v", err)
}
token, err := sessions.Create(user.ID)
if err != nil {
t.Fatalf("create session: %v", err)
}
var ctxUser *model.User
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxUser = UserFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
handler := SessionLoader(sessions, users)(inner)
req := httptest.NewRequest("GET", "/faves", nil)
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if ctxUser == nil {
t.Fatal("expected user in context, got nil")
}
if ctxUser.ID != user.ID {
t.Errorf("context user ID = %d, want %d", ctxUser.ID, user.ID)
}
}
func TestSessionLoaderInvalidToken(t *testing.T) {
sessions, users := testStores(t)
var ctxUser *model.User
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxUser = UserFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
handler := SessionLoader(sessions, users)(inner)
req := httptest.NewRequest("GET", "/faves", nil)
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "invalid-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("invalid token: got %d, want 200", rr.Code)
}
if ctxUser != nil {
t.Error("expected nil user for invalid token")
}
// Should clear the invalid session cookie.
for _, c := range rr.Result().Cookies() {
if c.Name == SessionCookieName && c.MaxAge == -1 {
return // Cookie cleared, good.
}
}
t.Error("expected session cookie to be cleared for invalid token")
}
func TestSessionLoaderNoCookie(t *testing.T) {
sessions, users := testStores(t)
var ctxUser *model.User
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxUser = UserFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
handler := SessionLoader(sessions, users)(inner)
req := httptest.NewRequest("GET", "/faves", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("no cookie: got %d, want 200", rr.Code)
}
if ctxUser != nil {
t.Error("expected nil user when no cookie")
}
}
func TestSessionLoaderSkipsStaticPaths(t *testing.T) {
sessions, users := testStores(t)
var handlerCalled bool
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
})
handler := SessionLoader(sessions, users)(inner)
for _, path := range []string{"/static/css/style.css", "/uploads/image.jpg"} {
handlerCalled = false
req := httptest.NewRequest("GET", path, nil)
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if !handlerCalled {
t.Errorf("handler not called for %s", path)
}
if rr.Code != http.StatusOK {
t.Errorf("%s: got %d, want 200", path, rr.Code)
}
}
}
func TestSessionLoaderDisabledUser(t *testing.T) {
sessions, users := testStores(t)
user, _ := users.Create("testuser", "pass123", "user")
token, _ := sessions.Create(user.ID)
users.SetDisabled(user.ID, true)
var ctxUser *model.User
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxUser = UserFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
handler := SessionLoader(sessions, users)(inner)
req := httptest.NewRequest("GET", "/faves", nil)
req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if ctxUser != nil {
t.Error("disabled user should not be in context")
}
// Session should be deleted and cookie cleared.
for _, c := range rr.Result().Cookies() {
if c.Name == SessionCookieName && c.MaxAge == -1 {
return
}
}
t.Error("expected session cookie to be cleared for disabled user")
}
func TestRequireAdminRejectsNonAdmin(t *testing.T) {
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Regular user in context.
user := &model.User{ID: 1, Username: "regular", Role: "user"}
ctx := context.WithValue(context.Background(), userKey, user)
req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("non-admin: got %d, want 403", rr.Code)
}
}
func TestRequireAdminAllowsAdmin(t *testing.T) {
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
user := &model.User{ID: 1, Username: "admin", Role: "admin"}
ctx := context.WithValue(context.Background(), userKey, user)
req := httptest.NewRequest("GET", "/admin", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("admin: got %d, want 200", rr.Code)
}
}
func TestRequireAdminNoUser(t *testing.T) {
handler := RequireAdmin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/admin", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("no user: got %d, want 403", rr.Code)
}
}
func TestContextHelpers(t *testing.T) {
// UserFromContext with user.
user := &model.User{ID: 42, Username: "tester"}
ctx := context.WithValue(context.Background(), userKey, user)
got := UserFromContext(ctx)
if got == nil || got.ID != 42 {
t.Errorf("UserFromContext with user: got %v", got)
}
// UserFromContext without user.
if UserFromContext(context.Background()) != nil {
t.Error("UserFromContext should return nil for empty context")
}
// CSRFTokenFromContext.
ctx = context.WithValue(context.Background(), csrfTokenKey, "my-token")
if CSRFTokenFromContext(ctx) != "my-token" {
t.Error("CSRFTokenFromContext failed")
}
if CSRFTokenFromContext(context.Background()) != "" {
t.Error("CSRFTokenFromContext should return empty for missing key")
}
// RealIPFromContext.
ctx = context.WithValue(context.Background(), realIPKey, "1.2.3.4")
if RealIPFromContext(ctx) != "1.2.3.4" {
t.Error("RealIPFromContext failed")
}
if RealIPFromContext(context.Background()) != "" {
t.Error("RealIPFromContext should return empty for missing key")
}
}
func TestMustResetPasswordGuardRedirects(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := MustResetPasswordGuard("")(inner)
// User with must_reset_password on a regular path → redirect.
user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
ctx := context.WithValue(context.Background(), userKey, user)
req := httptest.NewRequest("GET", "/faves", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusSeeOther {
t.Errorf("must-reset on /faves: got %d, want 303", rr.Code)
}
if loc := rr.Header().Get("Location"); loc != "/reset-password" {
t.Errorf("redirect location = %q, want /reset-password", loc)
}
}
func TestMustResetPasswordGuardAllowsPaths(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := MustResetPasswordGuard("")(inner)
user := &model.User{ID: 1, Username: "resetme", MustResetPassword: true}
ctx := context.WithValue(context.Background(), userKey, user)
// These paths should pass through even with must_reset.
allowedPaths := []string{"/reset-password", "/logout", "/health", "/static/css/style.css"}
for _, path := range allowedPaths {
req := httptest.NewRequest("GET", path, nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("must-reset on %s: got %d, want 200", path, rr.Code)
}
}
}

View file

@ -0,0 +1,185 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"kode.naiv.no/olemd/favoritter/internal/config"
)
func testCSRFHandler(cfg *config.Config) http.Handler {
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Echo the CSRF token from context so tests can verify it.
token := CSRFTokenFromContext(r.Context())
w.Write([]byte(token))
})
return CSRFProtection(cfg)(inner)
}
func TestCSRFTokenSetInCookie(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("GET: got %d, want 200", rr.Code)
}
// Should set a csrf_token cookie.
var found bool
for _, c := range rr.Result().Cookies() {
if c.Name == "csrf_token" {
found = true
if c.Value == "" {
t.Error("csrf_token cookie is empty")
}
if c.HttpOnly {
t.Error("csrf_token cookie should not be HttpOnly (JS needs to read it)")
}
}
}
if !found {
t.Error("csrf_token cookie not set on first request")
}
// The context token should match the cookie.
body := rr.Body.String()
if body == "" {
t.Error("CSRF token not set in context")
}
}
func TestCSRFValidTokenAccepted(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
// First GET to obtain a token.
getReq := httptest.NewRequest("GET", "/", nil)
getRR := httptest.NewRecorder()
handler.ServeHTTP(getRR, getReq)
var token string
for _, c := range getRR.Result().Cookies() {
if c.Name == "csrf_token" {
token = c.Value
}
}
if token == "" {
t.Fatal("no csrf_token cookie from GET")
}
// POST with matching cookie + form field.
form := "csrf_token=" + token
req := httptest.NewRequest("POST", "/submit", strings.NewReader(form))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("valid CSRF POST: got %d, want 200", rr.Code)
}
}
func TestCSRFMismatchRejected(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
form := "csrf_token=wrong-token"
req := httptest.NewRequest("POST", "/submit", strings.NewReader(form))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "real-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("mismatched CSRF: got %d, want 403", rr.Code)
}
}
func TestCSRFMissingTokenRejected(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
// POST with cookie but no form field or header.
req := httptest.NewRequest("POST", "/submit", nil)
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("missing CSRF form field: got %d, want 403", rr.Code)
}
}
func TestCSRFHeaderFallback(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
token := "valid-header-token"
req := httptest.NewRequest("POST", "/submit", nil)
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: token})
req.Header.Set("X-CSRF-Token", token)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("CSRF via header: got %d, want 200", rr.Code)
}
}
func TestCSRFSkippedForAPI(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
// POST to /api/ path — should skip CSRF validation.
req := httptest.NewRequest("POST", "/api/v1/faves", strings.NewReader(`{}`))
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: "some-token"})
// Intentionally no CSRF form field or header.
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("API route CSRF skip: got %d, want 200", rr.Code)
}
}
func TestCSRFSafeMethodsPassThrough(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
for _, method := range []string{"GET", "HEAD", "OPTIONS"} {
req := httptest.NewRequest(method, "/page", nil)
// No CSRF cookie or token at all.
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("%s without CSRF: got %d, want 200", method, rr.Code)
}
}
}
func TestCSRFExistingCookieReused(t *testing.T) {
handler := testCSRFHandler(&config.Config{})
// Send a request with an existing csrf_token cookie.
existingToken := "pre-existing-token-value"
req := httptest.NewRequest("GET", "/", nil)
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: existingToken})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// The context token should be the existing one.
body := rr.Body.String()
if body != existingToken {
t.Errorf("context token = %q, want existing %q", body, existingToken)
}
// Should NOT set a new cookie (existing one is reused).
for _, c := range rr.Result().Cookies() {
if c.Name == "csrf_token" {
t.Error("should not set new csrf_token cookie when one already exists")
}
}
}

View file

@ -0,0 +1,161 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package render
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"kode.naiv.no/olemd/favoritter/internal/config"
)
func testRenderer(t *testing.T) *Renderer {
t.Helper()
cfg := &config.Config{
SiteName: "Test Site",
BasePath: "/test",
}
r, err := New(cfg)
if err != nil {
t.Fatalf("create renderer: %v", err)
}
return r
}
func TestRenderPage(t *testing.T) {
r := testRenderer(t)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
r.Page(rr, req, "login", PageData{
Title: "Logg inn",
})
if rr.Code != http.StatusOK {
t.Errorf("render page: got %d, want 200", rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, "Logg inn") {
t.Error("page should contain title")
}
ct := rr.Header().Get("Content-Type")
if !strings.Contains(ct, "text/html") {
t.Errorf("content-type = %q, want text/html", ct)
}
}
func TestRenderPageWithData(t *testing.T) {
r := testRenderer(t)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
r.Page(rr, req, "login", PageData{
Title: "Test Page",
SiteName: "My Site",
BasePath: "/base",
})
body := rr.Body.String()
if !strings.Contains(body, "Test Page") {
t.Error("should contain title in output")
}
}
func TestRenderMissingTemplate(t *testing.T) {
r := testRenderer(t)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
r.Page(rr, req, "nonexistent_template_xyz", PageData{})
if rr.Code != http.StatusInternalServerError {
t.Errorf("missing template: got %d, want 500", rr.Code)
}
}
func TestRenderErrorPage(t *testing.T) {
r := testRenderer(t)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
r.Error(rr, req, http.StatusNotFound, "Ikke funnet")
if rr.Code != http.StatusNotFound {
t.Errorf("error page: got %d, want 404", rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, "Ikke funnet") {
t.Error("error page should contain message")
}
}
func TestRenderPopulatesCommonData(t *testing.T) {
cfg := &config.Config{
SiteName: "Favoritter",
BasePath: "/app",
ExternalURL: "https://example.com",
}
r, err := New(cfg)
if err != nil {
t.Fatalf("create renderer: %v", err)
}
req := httptest.NewRequest("GET", "/", nil)
// Add a CSRF token to the context.
type contextKey string
ctx := context.WithValue(req.Context(), contextKey("csrf_token"), "test-token")
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
r.Page(rr, req, "login", PageData{Title: "Test"})
// BasePath and SiteName should be populated from config.
body := rr.Body.String()
if !strings.Contains(body, "/app") {
t.Error("should contain basePath from config")
}
}
func TestTemplateFuncs(t *testing.T) {
cfg := &config.Config{BasePath: "/test", ExternalURL: "https://example.com"}
r, _ := New(cfg)
funcs := r.templateFuncs()
// Test truncate function.
truncate := funcs["truncate"].(func(int, string) string)
if got := truncate(5, "Hello, world!"); got != "Hello..." {
t.Errorf("truncate(5, long) = %q, want Hello...", got)
}
if got := truncate(20, "Short"); got != "Short" {
t.Errorf("truncate(20, short) = %q, want Short", got)
}
// Test with Norwegian characters (Ærlig = 5 runes: Æ r l i g).
if got := truncate(3, "Ærlig"); got != "Ærl..." {
t.Errorf("truncate(3, Ærlig) = %q, want Ærl...", got)
}
// Test add/subtract.
add := funcs["add"].(func(int, int) int)
if add(2, 3) != 5 {
t.Error("add(2,3) should be 5")
}
sub := funcs["subtract"].(func(int, int) int)
if sub(5, 3) != 2 {
t.Error("subtract(5,3) should be 2")
}
// Test basePath function.
bp := funcs["basePath"].(func() string)
if bp() != "/test" {
t.Errorf("basePath() = %q", bp())
}
}

View file

@ -0,0 +1,75 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package store
import (
"testing"
)
func TestSettingsGetDefault(t *testing.T) {
db := testDB(t)
settings := NewSettingsStore(db)
s, err := settings.Get()
if err != nil {
t.Fatalf("get default settings: %v", err)
}
// Migration should insert a default row.
if s.SiteName == "" {
t.Error("expected non-empty default site name")
}
if s.SignupMode == "" {
t.Error("expected non-empty default signup mode")
}
}
func TestSettingsUpdate(t *testing.T) {
db := testDB(t)
settings := NewSettingsStore(db)
err := settings.Update("Nye Favoritter", "En kul side", "requests")
if err != nil {
t.Fatalf("update settings: %v", err)
}
s, err := settings.Get()
if err != nil {
t.Fatalf("get after update: %v", err)
}
if s.SiteName != "Nye Favoritter" {
t.Errorf("site_name = %q, want %q", s.SiteName, "Nye Favoritter")
}
if s.SiteDescription != "En kul side" {
t.Errorf("site_description = %q, want %q", s.SiteDescription, "En kul side")
}
if s.SignupMode != "requests" {
t.Errorf("signup_mode = %q, want %q", s.SignupMode, "requests")
}
if s.UpdatedAt.IsZero() {
t.Error("expected non-zero updated_at after update")
}
}
func TestSettingsUpdatePreservesOtherFields(t *testing.T) {
db := testDB(t)
settings := NewSettingsStore(db)
// Set initial values.
settings.Update("Site A", "Desc A", "open")
// Update with different values.
settings.Update("Site B", "Desc B", "closed")
s, _ := settings.Get()
if s.SiteName != "Site B" {
t.Errorf("site_name = %q, want %q", s.SiteName, "Site B")
}
if s.SiteDescription != "Desc B" {
t.Errorf("site_description = %q, want %q", s.SiteDescription, "Desc B")
}
if s.SignupMode != "closed" {
t.Errorf("signup_mode = %q, want %q", s.SignupMode, "closed")
}
}