feat: implement Phase 1 (auth) and Phase 2 (faves CRUD) foundation
Go backend with server-rendered HTML/HTMX frontend, SQLite database, and filesystem image storage. Self-hostable single-binary architecture. Phase 1 — Authentication & project foundation: - Argon2id password hashing with timing-attack prevention - Session management with cookie-based auth and periodic cleanup - Login, signup (open/requests/closed modes), logout, forced password reset - CSRF double-submit cookie pattern with HTMX auto-inclusion - Proxy-aware real IP extraction (WireGuard/Tailscale support) - Configurable base path for subdomain and subpath deployment - Rate limiting on auth endpoints with background cleanup - Security headers (CSP, X-Frame-Options, Referrer-Policy) - Structured logging with slog, graceful shutdown - Pico CSS + HTMX vendored and embedded via go:embed Phase 2 — Faves CRUD with tags and images: - Full CRUD for favorites with ownership checks - Image upload with EXIF stripping, resize to 1920px, UUID filenames - Tag system with HTMX autocomplete (prefix search, popularity-sorted) - Privacy controls (public/private per fave, user-configurable default) - Tag browsing, pagination, batch tag loading (avoids N+1) - OpenGraph meta tags on public fave detail pages Includes code quality pass: extracted shared helpers, fixed signup request persistence bug, plugged rate limiter memory leak, removed dead code, and logged previously-swallowed errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
commit
fc1f7259c5
52 changed files with 5459 additions and 0 deletions
276
internal/store/fave.go
Normal file
276
internal/store/fave.go
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/model"
|
||||
)
|
||||
|
||||
var ErrFaveNotFound = errors.New("fave not found")
|
||||
|
||||
type FaveStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewFaveStore(db *sql.DB) *FaveStore {
|
||||
return &FaveStore{db: db}
|
||||
}
|
||||
|
||||
// Create inserts a new fave and returns it with its ID populated.
|
||||
func (s *FaveStore) Create(userID int64, description, url, imagePath, privacy string) (*model.Fave, error) {
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO faves (user_id, description, url, image_path, privacy)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
userID, description, url, imagePath, privacy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert fave: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// GetByID returns a fave by its ID, including joined user info.
|
||||
func (s *FaveStore) GetByID(id int64) (*model.Fave, error) {
|
||||
f := &model.Fave{}
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(
|
||||
`SELECT f.id, f.user_id, f.description, f.url, f.image_path, f.privacy,
|
||||
f.created_at, f.updated_at, u.username, u.display_name
|
||||
FROM faves f
|
||||
JOIN users u ON u.id = f.user_id
|
||||
WHERE f.id = ?`, id,
|
||||
).Scan(
|
||||
&f.ID, &f.UserID, &f.Description, &f.URL, &f.ImagePath, &f.Privacy,
|
||||
&createdAt, &updatedAt, &f.Username, &f.DisplayName,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrFaveNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query fave: %w", err)
|
||||
}
|
||||
f.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
f.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Update modifies an existing fave's fields.
|
||||
func (s *FaveStore) Update(id int64, description, url, imagePath, privacy string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE faves SET description = ?, url = ?, image_path = ?, privacy = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = ?`,
|
||||
description, url, imagePath, privacy, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete removes a fave by its ID. The cascade will clean up fave_tags.
|
||||
func (s *FaveStore) Delete(id int64) error {
|
||||
result, err := s.db.Exec("DELETE FROM faves WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrFaveNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByUser returns all faves for a user (both public and private),
|
||||
// ordered by newest first, with pagination.
|
||||
func (s *FaveStore) ListByUser(userID int64, limit, offset int) ([]*model.Fave, int, error) {
|
||||
var total int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM faves WHERE user_id = ?", userID).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT f.id, f.user_id, f.description, f.url, f.image_path, f.privacy,
|
||||
f.created_at, f.updated_at, u.username, u.display_name
|
||||
FROM faves f
|
||||
JOIN users u ON u.id = f.user_id
|
||||
WHERE f.user_id = ?
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
userID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
faves, err := s.scanFaves(rows)
|
||||
return faves, total, err
|
||||
}
|
||||
|
||||
// ListPublicByUser returns only public faves for a user, with pagination.
|
||||
func (s *FaveStore) ListPublicByUser(userID int64, limit, offset int) ([]*model.Fave, int, error) {
|
||||
var total int
|
||||
err := s.db.QueryRow(
|
||||
"SELECT COUNT(*) FROM faves WHERE user_id = ? AND privacy = 'public'", userID,
|
||||
).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT f.id, f.user_id, f.description, f.url, f.image_path, f.privacy,
|
||||
f.created_at, f.updated_at, u.username, u.display_name
|
||||
FROM faves f
|
||||
JOIN users u ON u.id = f.user_id
|
||||
WHERE f.user_id = ? AND f.privacy = 'public'
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
userID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
faves, err := s.scanFaves(rows)
|
||||
return faves, total, err
|
||||
}
|
||||
|
||||
// ListPublic returns all public faves across all users, with pagination.
|
||||
func (s *FaveStore) ListPublic(limit, offset int) ([]*model.Fave, int, error) {
|
||||
var total int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM faves WHERE privacy = 'public'").Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT f.id, f.user_id, f.description, f.url, f.image_path, f.privacy,
|
||||
f.created_at, f.updated_at, u.username, u.display_name
|
||||
FROM faves f
|
||||
JOIN users u ON u.id = f.user_id
|
||||
WHERE f.privacy = 'public'
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
faves, err := s.scanFaves(rows)
|
||||
return faves, total, err
|
||||
}
|
||||
|
||||
// ListByTag returns all public faves with a given tag, with pagination.
|
||||
func (s *FaveStore) ListByTag(tagName string, limit, offset int) ([]*model.Fave, int, error) {
|
||||
var total int
|
||||
err := s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM faves f
|
||||
JOIN fave_tags ft ON ft.fave_id = f.id
|
||||
JOIN tags t ON t.id = ft.tag_id
|
||||
WHERE t.name = ? AND f.privacy = 'public'`, tagName,
|
||||
).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT f.id, f.user_id, f.description, f.url, f.image_path, f.privacy,
|
||||
f.created_at, f.updated_at, u.username, u.display_name
|
||||
FROM faves f
|
||||
JOIN users u ON u.id = f.user_id
|
||||
JOIN fave_tags ft ON ft.fave_id = f.id
|
||||
JOIN tags t ON t.id = ft.tag_id
|
||||
WHERE t.name = ? AND f.privacy = 'public'
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
tagName, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
faves, err := s.scanFaves(rows)
|
||||
return faves, total, err
|
||||
}
|
||||
|
||||
// Count returns the total number of faves.
|
||||
func (s *FaveStore) Count() (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM faves").Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// LoadTags populates the Tags field on each fave.
|
||||
func (s *FaveStore) LoadTags(faves []*model.Fave) error {
|
||||
if len(faves) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a map for fast lookup.
|
||||
faveMap := make(map[int64]*model.Fave, len(faves))
|
||||
ids := make([]any, len(faves))
|
||||
placeholders := make([]string, len(faves))
|
||||
for i, f := range faves {
|
||||
faveMap[f.ID] = f
|
||||
ids[i] = f.ID
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`SELECT ft.fave_id, t.id, t.name
|
||||
FROM fave_tags ft
|
||||
JOIN tags t ON t.id = ft.tag_id
|
||||
WHERE ft.fave_id IN (%s)
|
||||
ORDER BY t.name COLLATE NOCASE`,
|
||||
strings.Join(placeholders, ","),
|
||||
)
|
||||
|
||||
rows, err := s.db.Query(query, ids...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load tags: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var faveID int64
|
||||
var tag model.Tag
|
||||
if err := rows.Scan(&faveID, &tag.ID, &tag.Name); err != nil {
|
||||
return fmt.Errorf("scan tag: %w", err)
|
||||
}
|
||||
if f, ok := faveMap[faveID]; ok {
|
||||
f.Tags = append(f.Tags, tag)
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (s *FaveStore) scanFaves(rows *sql.Rows) ([]*model.Fave, error) {
|
||||
var faves []*model.Fave
|
||||
for rows.Next() {
|
||||
f := &model.Fave{}
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&f.ID, &f.UserID, &f.Description, &f.URL, &f.ImagePath, &f.Privacy,
|
||||
&createdAt, &updatedAt, &f.Username, &f.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan fave: %w", err)
|
||||
}
|
||||
f.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
f.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
faves = append(faves, f)
|
||||
}
|
||||
return faves, rows.Err()
|
||||
}
|
||||
|
||||
180
internal/store/fave_test.go
Normal file
180
internal/store/fave_test.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFaveCRUD(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
tags := NewTagStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
// Create a user first.
|
||||
user, err := users.Create("testuser", "password123", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
// Create a fave.
|
||||
fave, err := faves.Create(user.ID, "Blade Runner 2049", "https://example.com", "", "public")
|
||||
if err != nil {
|
||||
t.Fatalf("create fave: %v", err)
|
||||
}
|
||||
if fave.Description != "Blade Runner 2049" {
|
||||
t.Errorf("description = %q, want %q", fave.Description, "Blade Runner 2049")
|
||||
}
|
||||
if fave.Username != "testuser" {
|
||||
t.Errorf("username = %q, want %q", fave.Username, "testuser")
|
||||
}
|
||||
|
||||
// Get by ID.
|
||||
got, err := faves.GetByID(fave.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get fave: %v", err)
|
||||
}
|
||||
if got.Description != fave.Description {
|
||||
t.Errorf("got description = %q, want %q", got.Description, fave.Description)
|
||||
}
|
||||
|
||||
// Update.
|
||||
err = faves.Update(fave.ID, "Blade Runner 2049 (Final Cut)", "https://example.com/br2049", "", "private")
|
||||
if err != nil {
|
||||
t.Fatalf("update fave: %v", err)
|
||||
}
|
||||
updated, _ := faves.GetByID(fave.ID)
|
||||
if updated.Description != "Blade Runner 2049 (Final Cut)" {
|
||||
t.Errorf("updated description = %q", updated.Description)
|
||||
}
|
||||
if updated.Privacy != "private" {
|
||||
t.Errorf("updated privacy = %q, want private", updated.Privacy)
|
||||
}
|
||||
|
||||
// Set tags.
|
||||
err = tags.SetFaveTags(fave.ID, []string{"film", "sci-fi", "favoritt"})
|
||||
if err != nil {
|
||||
t.Fatalf("set tags: %v", err)
|
||||
}
|
||||
faveTags, err := tags.ForFave(fave.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("for fave: %v", err)
|
||||
}
|
||||
if len(faveTags) != 3 {
|
||||
t.Errorf("tag count = %d, want 3", len(faveTags))
|
||||
}
|
||||
|
||||
// List by user.
|
||||
list, total, err := faves.ListByUser(user.ID, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("list by user: %v", err)
|
||||
}
|
||||
if total != 1 || len(list) != 1 {
|
||||
t.Errorf("list by user: total=%d, len=%d", total, len(list))
|
||||
}
|
||||
|
||||
// Load tags for list.
|
||||
err = faves.LoadTags(list)
|
||||
if err != nil {
|
||||
t.Fatalf("load tags: %v", err)
|
||||
}
|
||||
if len(list[0].Tags) != 3 {
|
||||
t.Errorf("loaded tags = %d, want 3", len(list[0].Tags))
|
||||
}
|
||||
|
||||
// Public list should be empty (fave is now private).
|
||||
pubList, pubTotal, err := faves.ListPublicByUser(user.ID, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("list public: %v", err)
|
||||
}
|
||||
if pubTotal != 0 || len(pubList) != 0 {
|
||||
t.Errorf("public list should be empty: total=%d, len=%d", pubTotal, len(pubList))
|
||||
}
|
||||
|
||||
// Delete.
|
||||
err = faves.Delete(fave.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("delete fave: %v", err)
|
||||
}
|
||||
_, err = faves.GetByID(fave.ID)
|
||||
if err != ErrFaveNotFound {
|
||||
t.Errorf("deleted fave error = %v, want ErrFaveNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByTag(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
tags := NewTagStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
user, _ := users.Create("testuser", "password123", "user")
|
||||
|
||||
// Create two public faves with overlapping tags.
|
||||
f1, _ := faves.Create(user.ID, "Fave 1", "", "", "public")
|
||||
f2, _ := faves.Create(user.ID, "Fave 2", "", "", "public")
|
||||
f3, _ := faves.Create(user.ID, "Private Fave", "", "", "private")
|
||||
|
||||
tags.SetFaveTags(f1.ID, []string{"music", "jazz"})
|
||||
tags.SetFaveTags(f2.ID, []string{"music", "rock"})
|
||||
tags.SetFaveTags(f3.ID, []string{"music", "secret"})
|
||||
|
||||
// ListByTag only returns public faves.
|
||||
list, total, err := faves.ListByTag("music", 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("list by tag: %v", err)
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2 (private fave should be excluded)", total)
|
||||
}
|
||||
if len(list) != 2 {
|
||||
t.Errorf("len = %d, want 2", len(list))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFavePagination(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
user, _ := users.Create("testuser", "password123", "user")
|
||||
|
||||
// Create 5 faves.
|
||||
for i := 0; i < 5; i++ {
|
||||
faves.Create(user.ID, "Fave "+string(rune('A'+i)), "", "", "public")
|
||||
}
|
||||
|
||||
// Page 1 with limit 2.
|
||||
page1, total, err := faves.ListByUser(user.ID, 2, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("page 1: %v", err)
|
||||
}
|
||||
if total != 5 {
|
||||
t.Errorf("total = %d, want 5", total)
|
||||
}
|
||||
if len(page1) != 2 {
|
||||
t.Errorf("page 1 len = %d, want 2", len(page1))
|
||||
}
|
||||
|
||||
// Page 3 with limit 2 should have 1 item.
|
||||
page3, _, err := faves.ListByUser(user.ID, 2, 4)
|
||||
if err != nil {
|
||||
t.Fatalf("page 3: %v", err)
|
||||
}
|
||||
if len(page3) != 1 {
|
||||
t.Errorf("page 3 len = %d, want 1", len(page3))
|
||||
}
|
||||
}
|
||||
119
internal/store/session.go
Normal file
119
internal/store/session.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/model"
|
||||
)
|
||||
|
||||
var ErrSessionNotFound = errors.New("session not found")
|
||||
|
||||
type SessionStore struct {
|
||||
db *sql.DB
|
||||
lifetime time.Duration
|
||||
}
|
||||
|
||||
func NewSessionStore(db *sql.DB) *SessionStore {
|
||||
return &SessionStore{db: db, lifetime: 720 * time.Hour} // default 30 days
|
||||
}
|
||||
|
||||
// SetLifetime configures the session lifetime.
|
||||
func (s *SessionStore) SetLifetime(d time.Duration) {
|
||||
s.lifetime = d
|
||||
}
|
||||
|
||||
// Create generates a new session token for the given user.
|
||||
func (s *SessionStore) Create(userID int64) (string, error) {
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("generate session token: %w", err)
|
||||
}
|
||||
token := hex.EncodeToString(tokenBytes)
|
||||
|
||||
expiresAt := time.Now().UTC().Add(s.lifetime)
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`,
|
||||
token, userID, expiresAt.Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Validate checks if a session token is valid and not expired.
|
||||
// Returns the session if valid.
|
||||
func (s *SessionStore) Validate(token string) (*model.Session, error) {
|
||||
var session model.Session
|
||||
var expiresAt, createdAt string
|
||||
err := s.db.QueryRow(
|
||||
`SELECT token, user_id, expires_at, created_at
|
||||
FROM sessions WHERE token = ?`, token,
|
||||
).Scan(&session.Token, &session.UserID, &expiresAt, &createdAt)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query session: %w", err)
|
||||
}
|
||||
|
||||
session.ExpiresAt, _ = time.Parse(time.RFC3339, expiresAt)
|
||||
session.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
|
||||
if time.Now().UTC().After(session.ExpiresAt) {
|
||||
// Session has expired — delete it.
|
||||
s.Delete(token)
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// Delete removes a session by its token.
|
||||
func (s *SessionStore) Delete(token string) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE token = ?", token)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteAllForUser removes all sessions for a given user (e.g., on password change).
|
||||
func (s *SessionStore) DeleteAllForUser(userID int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM sessions WHERE user_id = ?", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CleanupLoop periodically removes expired sessions. It runs until the
|
||||
// context is canceled.
|
||||
func (s *SessionStore) CleanupLoop(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
result, err := s.db.Exec(
|
||||
`DELETE FROM sessions WHERE expires_at < strftime('%Y-%m-%dT%H:%M:%SZ', 'now')`,
|
||||
)
|
||||
if err != nil {
|
||||
slog.Error("session cleanup failed", "error", err)
|
||||
continue
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n > 0 {
|
||||
slog.Info("cleaned up expired sessions", "count", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
46
internal/store/settings.go
Normal file
46
internal/store/settings.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/model"
|
||||
)
|
||||
|
||||
type SettingsStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewSettingsStore(db *sql.DB) *SettingsStore {
|
||||
return &SettingsStore{db: db}
|
||||
}
|
||||
|
||||
// Get returns the current site settings.
|
||||
func (s *SettingsStore) Get() (*model.SiteSettings, error) {
|
||||
var settings model.SiteSettings
|
||||
var updatedAt string
|
||||
err := s.db.QueryRow(
|
||||
`SELECT site_name, site_description, signup_mode, updated_at
|
||||
FROM site_settings WHERE id = 1`,
|
||||
).Scan(&settings.SiteName, &settings.SiteDescription, &settings.SignupMode, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query site settings: %w", err)
|
||||
}
|
||||
|
||||
settings.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// Update updates the site settings.
|
||||
func (s *SettingsStore) Update(siteName, siteDescription, signupMode string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE site_settings SET site_name = ?, site_description = ?, signup_mode = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = 1`,
|
||||
siteName, siteDescription, signupMode,
|
||||
)
|
||||
return err
|
||||
}
|
||||
47
internal/store/signup_request.go
Normal file
47
internal/store/signup_request.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrSignupRequestExists = errors.New("signup request already exists")
|
||||
|
||||
type SignupRequestStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewSignupRequestStore(db *sql.DB) *SignupRequestStore {
|
||||
return &SignupRequestStore{db: db}
|
||||
}
|
||||
|
||||
// Create stores a pending signup request with a hashed password.
|
||||
func (s *SignupRequestStore) Create(username, password string) error {
|
||||
hash, err := hashPassword(password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO signup_requests (username, password_hash) VALUES (?, ?)`,
|
||||
username, hash,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
return ErrSignupRequestExists
|
||||
}
|
||||
return fmt.Errorf("insert signup request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PendingCount returns the number of pending signup requests.
|
||||
func (s *SignupRequestStore) PendingCount() (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM signup_requests WHERE status = 'pending'").Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
240
internal/store/tag.go
Normal file
240
internal/store/tag.go
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxTagsPerFave = 20
|
||||
MaxTagLength = 50
|
||||
)
|
||||
|
||||
var ErrTagNotFound = errors.New("tag not found")
|
||||
|
||||
type TagStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewTagStore(db *sql.DB) *TagStore {
|
||||
return &TagStore{db: db}
|
||||
}
|
||||
|
||||
// Search returns tags matching a prefix query, for autocomplete.
|
||||
// Results are ordered by how many faves use each tag (most popular first).
|
||||
func (s *TagStore) Search(query string, limit int) ([]model.Tag, error) {
|
||||
query = strings.TrimSpace(strings.ToLower(query))
|
||||
if query == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.name
|
||||
FROM tags t
|
||||
LEFT JOIN fave_tags ft ON ft.tag_id = t.id
|
||||
WHERE t.name LIKE ? || '%'
|
||||
GROUP BY t.id
|
||||
ORDER BY COUNT(ft.fave_id) DESC, t.name COLLATE NOCASE
|
||||
LIMIT ?`,
|
||||
query, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search tags: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTags(rows)
|
||||
}
|
||||
|
||||
func scanTags(rows *sql.Rows) ([]model.Tag, error) {
|
||||
var tags []model.Tag
|
||||
for rows.Next() {
|
||||
var t model.Tag
|
||||
if err := rows.Scan(&t.ID, &t.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tags = append(tags, t)
|
||||
}
|
||||
return tags, rows.Err()
|
||||
}
|
||||
|
||||
// GetOrCreate returns an existing tag by name, or creates it if it doesn't exist.
|
||||
// Tag names are normalized to lowercase and trimmed.
|
||||
func (s *TagStore) GetOrCreate(name string) (*model.Tag, error) {
|
||||
name = NormalizeTagName(name)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("empty tag name")
|
||||
}
|
||||
if len(name) > MaxTagLength {
|
||||
return nil, fmt.Errorf("tag name too long (max %d characters)", MaxTagLength)
|
||||
}
|
||||
|
||||
// Try to find existing tag first (COLLATE NOCASE handles case).
|
||||
var tag model.Tag
|
||||
err := s.db.QueryRow("SELECT id, name FROM tags WHERE name = ?", name).Scan(&tag.ID, &tag.Name)
|
||||
if err == nil {
|
||||
return &tag, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create new tag.
|
||||
result, err := s.db.Exec("INSERT INTO tags (name) VALUES (?)", name)
|
||||
if err != nil {
|
||||
// Race condition: another request may have created it.
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
err2 := s.db.QueryRow("SELECT id, name FROM tags WHERE name = ?", name).Scan(&tag.ID, &tag.Name)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return &tag, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tag.ID, _ = result.LastInsertId()
|
||||
tag.Name = name
|
||||
return &tag, nil
|
||||
}
|
||||
|
||||
// GetByID returns a tag by ID.
|
||||
func (s *TagStore) GetByID(id int64) (*model.Tag, error) {
|
||||
var tag model.Tag
|
||||
err := s.db.QueryRow("SELECT id, name FROM tags WHERE id = ?", id).Scan(&tag.ID, &tag.Name)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrTagNotFound
|
||||
}
|
||||
return &tag, err
|
||||
}
|
||||
|
||||
// GetByName returns a tag by its name (case-insensitive).
|
||||
func (s *TagStore) GetByName(name string) (*model.Tag, error) {
|
||||
var tag model.Tag
|
||||
err := s.db.QueryRow("SELECT id, name FROM tags WHERE name = ?", name).Scan(&tag.ID, &tag.Name)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrTagNotFound
|
||||
}
|
||||
return &tag, err
|
||||
}
|
||||
|
||||
// AttachToFave links a tag to a fave. No-op if already attached.
|
||||
func (s *TagStore) AttachToFave(faveID, tagID int64) error {
|
||||
_, err := s.db.Exec(
|
||||
"INSERT OR IGNORE INTO fave_tags (fave_id, tag_id) VALUES (?, ?)",
|
||||
faveID, tagID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DetachFromFave removes a tag from a fave.
|
||||
func (s *TagStore) DetachFromFave(faveID, tagID int64) error {
|
||||
_, err := s.db.Exec(
|
||||
"DELETE FROM fave_tags WHERE fave_id = ? AND tag_id = ?",
|
||||
faveID, tagID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetFaveTags replaces all tags on a fave with the given tag names.
|
||||
// Creates new tags as needed. Enforces MaxTagsPerFave.
|
||||
func (s *TagStore) SetFaveTags(faveID int64, tagNames []string) error {
|
||||
if len(tagNames) > MaxTagsPerFave {
|
||||
tagNames = tagNames[:MaxTagsPerFave]
|
||||
}
|
||||
|
||||
// Remove all existing tags for this fave.
|
||||
if _, err := s.db.Exec("DELETE FROM fave_tags WHERE fave_id = ?", faveID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Attach new tags.
|
||||
for _, name := range tagNames {
|
||||
name = NormalizeTagName(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
tag, err := s.GetOrCreate(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get or create tag %q: %w", name, err)
|
||||
}
|
||||
|
||||
if err := s.AttachToFave(faveID, tag.ID); err != nil {
|
||||
return fmt.Errorf("attach tag %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForFave returns all tags attached to a fave.
|
||||
func (s *TagStore) ForFave(faveID int64) ([]model.Tag, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT t.id, t.name FROM tags t
|
||||
JOIN fave_tags ft ON ft.tag_id = t.id
|
||||
WHERE ft.fave_id = ?
|
||||
ORDER BY t.name COLLATE NOCASE`,
|
||||
faveID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTags(rows)
|
||||
}
|
||||
|
||||
// ListAll returns all tags ordered by name.
|
||||
func (s *TagStore) ListAll() ([]model.Tag, error) {
|
||||
rows, err := s.db.Query("SELECT id, name FROM tags ORDER BY name COLLATE NOCASE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanTags(rows)
|
||||
}
|
||||
|
||||
// Rename changes a tag's name. Returns error if the new name already exists.
|
||||
func (s *TagStore) Rename(id int64, newName string) error {
|
||||
newName = NormalizeTagName(newName)
|
||||
if newName == "" {
|
||||
return fmt.Errorf("empty tag name")
|
||||
}
|
||||
_, err := s.db.Exec("UPDATE tags SET name = ? WHERE id = ?", newName, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete removes a tag and all its fave associations (via cascade).
|
||||
func (s *TagStore) Delete(id int64) error {
|
||||
// fave_tags rows are cleaned up by ON DELETE CASCADE.
|
||||
result, err := s.db.Exec("DELETE FROM tags WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrTagNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupOrphans removes tags that are not attached to any faves.
|
||||
func (s *TagStore) CleanupOrphans() (int64, error) {
|
||||
result, err := s.db.Exec(
|
||||
`DELETE FROM tags WHERE id NOT IN (SELECT DISTINCT tag_id FROM fave_tags)`,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// NormalizeTagName lowercases and trims a tag name.
|
||||
func NormalizeTagName(name string) string {
|
||||
return strings.TrimSpace(strings.ToLower(name))
|
||||
}
|
||||
173
internal/store/tag_test.go
Normal file
173
internal/store/tag_test.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTagGetOrCreate(t *testing.T) {
|
||||
db := testDB(t)
|
||||
tags := NewTagStore(db)
|
||||
|
||||
// Create a new tag.
|
||||
tag1, err := tags.GetOrCreate("Film")
|
||||
if err != nil {
|
||||
t.Fatalf("get or create: %v", err)
|
||||
}
|
||||
if tag1.Name != "film" {
|
||||
t.Errorf("name = %q, want %q (should be normalized to lowercase)", tag1.Name, "film")
|
||||
}
|
||||
|
||||
// Get the same tag again (case-insensitive).
|
||||
tag2, err := tags.GetOrCreate("FILM")
|
||||
if err != nil {
|
||||
t.Fatalf("get or create duplicate: %v", err)
|
||||
}
|
||||
if tag2.ID != tag1.ID {
|
||||
t.Errorf("duplicate tag got different ID: %d vs %d", tag2.ID, tag1.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagSearch(t *testing.T) {
|
||||
db := testDB(t)
|
||||
tags := NewTagStore(db)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
user, _ := users.Create("testuser", "password123", "user")
|
||||
|
||||
// Create some tags via faves to give them usage counts.
|
||||
f1, _ := faves.Create(user.ID, "F1", "", "", "public")
|
||||
f2, _ := faves.Create(user.ID, "F2", "", "", "public")
|
||||
|
||||
tags.SetFaveTags(f1.ID, []string{"music", "movies", "misc"})
|
||||
tags.SetFaveTags(f2.ID, []string{"music", "manga"})
|
||||
|
||||
// Search for "mu" should return "music" first (2 faves) then nothing else.
|
||||
results, err := tags.Search("mu", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("search: %v", err)
|
||||
}
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("search results = %d, want 1", len(results))
|
||||
}
|
||||
if results[0].Name != "music" {
|
||||
t.Errorf("first result = %q, want %q", results[0].Name, "music")
|
||||
}
|
||||
|
||||
// Search for "m" should return music (2), manga (1), misc (1), movies (1).
|
||||
results, err = tags.Search("m", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("search: %v", err)
|
||||
}
|
||||
if len(results) != 4 {
|
||||
t.Errorf("search results = %d, want 4", len(results))
|
||||
}
|
||||
// Music should be first due to highest usage count.
|
||||
if results[0].Name != "music" {
|
||||
t.Errorf("first result = %q, want %q (most used)", results[0].Name, "music")
|
||||
}
|
||||
|
||||
// Empty search returns nothing.
|
||||
results, err = tags.Search("", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("empty search: %v", err)
|
||||
}
|
||||
if len(results) != 0 {
|
||||
t.Errorf("empty search results = %d, want 0", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagSetFaveTagsLimit(t *testing.T) {
|
||||
db := testDB(t)
|
||||
tags := NewTagStore(db)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
user, _ := users.Create("testuser", "password123", "user")
|
||||
fave, _ := faves.Create(user.ID, "Test", "", "", "public")
|
||||
|
||||
// Try to set more than MaxTagsPerFave tags.
|
||||
manyTags := make([]string, 30)
|
||||
for i := range manyTags {
|
||||
manyTags[i] = "tag" + string(rune('a'+i%26))
|
||||
}
|
||||
|
||||
err := tags.SetFaveTags(fave.ID, manyTags)
|
||||
if err != nil {
|
||||
t.Fatalf("set many tags: %v", err)
|
||||
}
|
||||
|
||||
attached, _ := tags.ForFave(fave.ID)
|
||||
if len(attached) > MaxTagsPerFave {
|
||||
t.Errorf("attached %d tags, max should be %d", len(attached), MaxTagsPerFave)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagCleanupOrphans(t *testing.T) {
|
||||
db := testDB(t)
|
||||
tags := NewTagStore(db)
|
||||
users := NewUserStore(db)
|
||||
faves := NewFaveStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() { Argon2Memory = 65536; Argon2Time = 3 }()
|
||||
|
||||
user, _ := users.Create("testuser", "password123", "user")
|
||||
fave, _ := faves.Create(user.ID, "Test", "", "", "public")
|
||||
|
||||
tags.SetFaveTags(fave.ID, []string{"keep", "orphan"})
|
||||
|
||||
// Remove the fave — "keep" and "orphan" are now orphaned.
|
||||
faves.Delete(fave.ID)
|
||||
|
||||
removed, err := tags.CleanupOrphans()
|
||||
if err != nil {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
if removed != 2 {
|
||||
t.Errorf("removed = %d, want 2", removed)
|
||||
}
|
||||
|
||||
all, _ := tags.ListAll()
|
||||
if len(all) != 0 {
|
||||
t.Errorf("remaining tags = %d, want 0", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagRenameAndDelete(t *testing.T) {
|
||||
db := testDB(t)
|
||||
tags := NewTagStore(db)
|
||||
|
||||
tag, _ := tags.GetOrCreate("oldname")
|
||||
|
||||
err := tags.Rename(tag.ID, "NewName")
|
||||
if err != nil {
|
||||
t.Fatalf("rename: %v", err)
|
||||
}
|
||||
|
||||
renamed, _ := tags.GetByID(tag.ID)
|
||||
if renamed.Name != "newname" {
|
||||
t.Errorf("renamed = %q, want %q", renamed.Name, "newname")
|
||||
}
|
||||
|
||||
err = tags.Delete(tag.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("delete: %v", err)
|
||||
}
|
||||
|
||||
_, err = tags.GetByID(tag.ID)
|
||||
if err != ErrTagNotFound {
|
||||
t.Errorf("deleted tag error = %v, want ErrTagNotFound", err)
|
||||
}
|
||||
}
|
||||
326
internal/store/user.go
Normal file
326
internal/store/user.go
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/model"
|
||||
)
|
||||
|
||||
// Argon2id parameters. Defaults match OWASP recommendations.
|
||||
var (
|
||||
Argon2Memory uint32 = 65536 // 64 MB
|
||||
Argon2Time uint32 = 3
|
||||
Argon2Parallelism uint8 = 2
|
||||
Argon2KeyLength uint32 = 32
|
||||
Argon2SaltLength = 16
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("username already taken")
|
||||
ErrUserDisabled = errors.New("user account is disabled")
|
||||
ErrInvalidCredentials = errors.New("invalid username or password")
|
||||
)
|
||||
|
||||
type UserStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUserStore(db *sql.DB) *UserStore {
|
||||
return &UserStore{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new user with the given username and plaintext password.
|
||||
func (s *UserStore) Create(username, password, role string) (*model.User, error) {
|
||||
hash, err := hashPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO users (username, password_hash, role) VALUES (?, ?, ?)`,
|
||||
username, hash, role,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
return nil, fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// CreateWithReset creates a new user that must reset their password on first login.
|
||||
func (s *UserStore) CreateWithReset(username, tempPassword, role string) (*model.User, error) {
|
||||
hash, err := hashPassword(tempPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO users (username, password_hash, role, must_reset_password) VALUES (?, ?, ?, 1)`,
|
||||
username, hash, role,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
return nil, fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// Authenticate verifies credentials and returns the user if valid.
|
||||
func (s *UserStore) Authenticate(username, password string) (*model.User, error) {
|
||||
user, err := s.GetByUsername(username)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// Still do a dummy hash comparison to prevent timing attacks.
|
||||
dummyHash(password)
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.Disabled {
|
||||
return nil, ErrUserDisabled
|
||||
}
|
||||
|
||||
if !verifyPassword(password, user.PasswordHash) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a user by their ID.
|
||||
func (s *UserStore) GetByID(id int64) (*model.User, error) {
|
||||
return scanUserFrom(s.db.QueryRow(
|
||||
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||||
role, profile_visibility, default_fave_privacy,
|
||||
must_reset_password, disabled, created_at, updated_at
|
||||
FROM users WHERE id = ?`, id,
|
||||
))
|
||||
}
|
||||
|
||||
// GetByUsername retrieves a user by their username (case-insensitive).
|
||||
func (s *UserStore) GetByUsername(username string) (*model.User, error) {
|
||||
return scanUserFrom(s.db.QueryRow(
|
||||
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||||
role, profile_visibility, default_fave_privacy,
|
||||
must_reset_password, disabled, created_at, updated_at
|
||||
FROM users WHERE username = ?`, username,
|
||||
))
|
||||
}
|
||||
|
||||
// UpdatePassword changes a user's password and clears the must_reset_password flag.
|
||||
func (s *UserStore) UpdatePassword(userID int64, newPassword string) error {
|
||||
hash, err := hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(
|
||||
`UPDATE users SET password_hash = ?, must_reset_password = 0,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = ?`,
|
||||
hash, userID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateProfile updates a user's profile fields.
|
||||
func (s *UserStore) UpdateProfile(userID int64, displayName, bio, profileVisibility, defaultFavePrivacy string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE users SET display_name = ?, bio = ?, profile_visibility = ?,
|
||||
default_fave_privacy = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = ?`,
|
||||
displayName, bio, profileVisibility, defaultFavePrivacy, userID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateAvatar updates a user's avatar path.
|
||||
func (s *UserStore) UpdateAvatar(userID int64, avatarPath string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE users SET avatar_path = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = ?`,
|
||||
avatarPath, userID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetDisabled enables or disables a user account.
|
||||
func (s *UserStore) SetDisabled(userID int64, disabled bool) error {
|
||||
val := 0
|
||||
if disabled {
|
||||
val = 1
|
||||
}
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE users SET disabled = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||||
WHERE id = ?`,
|
||||
val, userID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListAll returns all users, ordered by username.
|
||||
func (s *UserStore) ListAll() ([]*model.User, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||||
role, profile_visibility, default_fave_privacy,
|
||||
must_reset_password, disabled, created_at, updated_at
|
||||
FROM users ORDER BY username COLLATE NOCASE`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*model.User
|
||||
for rows.Next() {
|
||||
u, err := scanUserFrom(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// Count returns the total number of users.
|
||||
func (s *UserStore) Count() (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// EnsureAdmin creates the initial admin user if no users exist yet.
|
||||
// This is called on startup with the configured admin credentials.
|
||||
func (s *UserStore) EnsureAdmin(username, password string) error {
|
||||
if username == "" || password == "" {
|
||||
// No admin credentials configured — only skip if users already exist.
|
||||
count, err := s.Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
slog.Warn("no admin credentials configured and no users exist — set FAVORITTER_ADMIN_USERNAME and FAVORITTER_ADMIN_PASSWORD")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this admin already exists.
|
||||
_, err := s.GetByUsername(username)
|
||||
if err == nil {
|
||||
return nil // Already exists.
|
||||
}
|
||||
if !errors.Is(err, ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.Create(username, password, "admin")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create admin user: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("created initial admin user", "username", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanner is implemented by both *sql.Row and *sql.Rows.
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanUserFrom(s scanner) (*model.User, error) {
|
||||
u := &model.User{}
|
||||
var createdAt, updatedAt string
|
||||
err := s.Scan(
|
||||
&u.ID, &u.Username, &u.DisplayName, &u.Bio, &u.AvatarPath,
|
||||
&u.PasswordHash, &u.Role, &u.ProfileVisibility, &u.DefaultFavePrivacy,
|
||||
&u.MustResetPassword, &u.Disabled, &createdAt, &updatedAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
u.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
u.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Password hashing with Argon2id.
|
||||
// Format: $argon2id$v=19$m=65536,t=3,p=2$<salt>$<hash>
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
salt := make([]byte, Argon2SaltLength)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("generate salt: %w", err)
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength)
|
||||
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
Argon2Memory, Argon2Time, Argon2Parallelism,
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
), nil
|
||||
}
|
||||
|
||||
func verifyPassword(password, encodedHash string) bool {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return false
|
||||
}
|
||||
|
||||
var memory uint32
|
||||
var iterations uint32
|
||||
var parallelism uint8
|
||||
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash)))
|
||||
|
||||
return subtle.ConstantTimeCompare(hash, expectedHash) == 1
|
||||
}
|
||||
|
||||
// dummyHash performs a hash operation to prevent timing-based username enumeration.
|
||||
func dummyHash(password string) {
|
||||
salt := make([]byte, Argon2SaltLength)
|
||||
argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength)
|
||||
}
|
||||
205
internal/store/user_test.go
Normal file
205
internal/store/user_test.go
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"kode.naiv.no/olemd/favoritter/internal/database"
|
||||
)
|
||||
|
||||
func testDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := database.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open test db: %v", err)
|
||||
}
|
||||
if err := database.Migrate(db); err != nil {
|
||||
t.Fatalf("migrate test db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func TestCreateAndAuthenticate(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
|
||||
// Use fast Argon2 parameters for tests.
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() {
|
||||
Argon2Memory = 65536
|
||||
Argon2Time = 3
|
||||
}()
|
||||
|
||||
// Create a user.
|
||||
user, err := users.Create("testuser", "password123", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
if user.Username != "testuser" {
|
||||
t.Errorf("username = %q, want %q", user.Username, "testuser")
|
||||
}
|
||||
if user.Role != "user" {
|
||||
t.Errorf("role = %q, want %q", user.Role, "user")
|
||||
}
|
||||
|
||||
// Authenticate with correct password.
|
||||
authed, err := users.Authenticate("testuser", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("authenticate: %v", err)
|
||||
}
|
||||
if authed.ID != user.ID {
|
||||
t.Errorf("authenticated user ID = %d, want %d", authed.ID, user.ID)
|
||||
}
|
||||
|
||||
// Authenticate with wrong password.
|
||||
_, err = users.Authenticate("testuser", "wrongpassword")
|
||||
if err != ErrInvalidCredentials {
|
||||
t.Errorf("wrong password error = %v, want ErrInvalidCredentials", err)
|
||||
}
|
||||
|
||||
// Authenticate with non-existent user.
|
||||
_, err = users.Authenticate("nouser", "password123")
|
||||
if err != ErrInvalidCredentials {
|
||||
t.Errorf("non-existent user error = %v, want ErrInvalidCredentials", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDuplicate(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() {
|
||||
Argon2Memory = 65536
|
||||
Argon2Time = 3
|
||||
}()
|
||||
|
||||
_, err := users.Create("testuser", "password123", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
_, err = users.Create("testuser", "password456", "user")
|
||||
if err != ErrUserExists {
|
||||
t.Errorf("duplicate error = %v, want ErrUserExists", err)
|
||||
}
|
||||
|
||||
// Case-insensitive duplicate.
|
||||
_, err = users.Create("TestUser", "password456", "user")
|
||||
if err != ErrUserExists {
|
||||
t.Errorf("case-insensitive duplicate error = %v, want ErrUserExists", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePassword(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() {
|
||||
Argon2Memory = 65536
|
||||
Argon2Time = 3
|
||||
}()
|
||||
|
||||
user, err := users.CreateWithReset("admin", "temppass", "admin")
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
if !user.MustResetPassword {
|
||||
t.Error("expected must_reset_password to be true")
|
||||
}
|
||||
|
||||
err = users.UpdatePassword(user.ID, "newpassword123")
|
||||
if err != nil {
|
||||
t.Fatalf("update password: %v", err)
|
||||
}
|
||||
|
||||
// Verify old password no longer works.
|
||||
_, err = users.Authenticate("admin", "temppass")
|
||||
if err != ErrInvalidCredentials {
|
||||
t.Error("old password should not work after reset")
|
||||
}
|
||||
|
||||
// Verify new password works and reset flag is cleared.
|
||||
updated, err := users.Authenticate("admin", "newpassword123")
|
||||
if err != nil {
|
||||
t.Fatalf("authenticate with new password: %v", err)
|
||||
}
|
||||
if updated.MustResetPassword {
|
||||
t.Error("must_reset_password should be false after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAdmin(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() {
|
||||
Argon2Memory = 65536
|
||||
Argon2Time = 3
|
||||
}()
|
||||
|
||||
// First call creates the admin.
|
||||
err := users.EnsureAdmin("admin", "adminpass")
|
||||
if err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
|
||||
admin, err := users.GetByUsername("admin")
|
||||
if err != nil {
|
||||
t.Fatalf("get admin: %v", err)
|
||||
}
|
||||
if !admin.IsAdmin() {
|
||||
t.Error("expected admin role")
|
||||
}
|
||||
|
||||
// Second call is a no-op.
|
||||
err = users.EnsureAdmin("admin", "adminpass")
|
||||
if err != nil {
|
||||
t.Fatalf("ensure admin (second call): %v", err)
|
||||
}
|
||||
|
||||
count, _ := users.Count()
|
||||
if count != 1 {
|
||||
t.Errorf("user count = %d, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledUser(t *testing.T) {
|
||||
db := testDB(t)
|
||||
users := NewUserStore(db)
|
||||
|
||||
Argon2Memory = 1024
|
||||
Argon2Time = 1
|
||||
defer func() {
|
||||
Argon2Memory = 65536
|
||||
Argon2Time = 3
|
||||
}()
|
||||
|
||||
user, err := users.Create("testuser", "password123", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
// Disable the user.
|
||||
err = users.SetDisabled(user.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("disable user: %v", err)
|
||||
}
|
||||
|
||||
// Authentication should fail.
|
||||
_, err = users.Authenticate("testuser", "password123")
|
||||
if err != ErrUserDisabled {
|
||||
t.Errorf("disabled user error = %v, want ErrUserDisabled", err)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue