// SPDX-License-Identifier: AGPL-3.0-or-later package store import ( "crypto/rand" "crypto/subtle" "database/sql" "encoding/base64" "errors" "fmt" "log/slog" "strings" "time" "golang.org/x/crypto/argon2" "kode.naiv.no/olemd/favoritter/internal/model" ) // Argon2id parameters. Defaults match OWASP recommendations. var ( Argon2Memory uint32 = 65536 // 64 MB Argon2Time uint32 = 3 Argon2Parallelism uint8 = 2 Argon2KeyLength uint32 = 32 Argon2SaltLength = 16 ) var ( ErrUserNotFound = errors.New("user not found") ErrUserExists = errors.New("username already taken") ErrUserDisabled = errors.New("user account is disabled") ErrInvalidCredentials = errors.New("invalid username or password") ) type UserStore struct { db *sql.DB } func NewUserStore(db *sql.DB) *UserStore { return &UserStore{db: db} } // Create creates a new user with the given username and plaintext password. func (s *UserStore) Create(username, password, role string) (*model.User, error) { hash, err := hashPassword(password) if err != nil { return nil, fmt.Errorf("hash password: %w", err) } result, err := s.db.Exec( `INSERT INTO users (username, password_hash, role) VALUES (?, ?, ?)`, username, hash, role, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { return nil, ErrUserExists } return nil, fmt.Errorf("insert user: %w", err) } id, _ := result.LastInsertId() return s.GetByID(id) } // CreateWithReset creates a new user that must reset their password on first login. func (s *UserStore) CreateWithReset(username, tempPassword, role string) (*model.User, error) { hash, err := hashPassword(tempPassword) if err != nil { return nil, fmt.Errorf("hash password: %w", err) } result, err := s.db.Exec( `INSERT INTO users (username, password_hash, role, must_reset_password) VALUES (?, ?, ?, 1)`, username, hash, role, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { return nil, ErrUserExists } return nil, fmt.Errorf("insert user: %w", err) } id, _ := result.LastInsertId() return s.GetByID(id) } // Authenticate verifies credentials and returns the user if valid. func (s *UserStore) Authenticate(username, password string) (*model.User, error) { user, err := s.GetByUsername(username) if err != nil { if errors.Is(err, ErrUserNotFound) { // Still do a dummy hash comparison to prevent timing attacks. dummyHash(password) return nil, ErrInvalidCredentials } return nil, err } if user.Disabled { return nil, ErrUserDisabled } if !verifyPassword(password, user.PasswordHash) { return nil, ErrInvalidCredentials } return user, nil } // GetByID retrieves a user by their ID. func (s *UserStore) GetByID(id int64) (*model.User, error) { return scanUserFrom(s.db.QueryRow( `SELECT id, username, display_name, bio, avatar_path, password_hash, role, profile_visibility, default_fave_privacy, must_reset_password, disabled, created_at, updated_at FROM users WHERE id = ?`, id, )) } // GetByUsername retrieves a user by their username (case-insensitive). func (s *UserStore) GetByUsername(username string) (*model.User, error) { return scanUserFrom(s.db.QueryRow( `SELECT id, username, display_name, bio, avatar_path, password_hash, role, profile_visibility, default_fave_privacy, must_reset_password, disabled, created_at, updated_at FROM users WHERE username = ?`, username, )) } // UpdatePassword changes a user's password and clears the must_reset_password flag. func (s *UserStore) UpdatePassword(userID int64, newPassword string) error { hash, err := hashPassword(newPassword) if err != nil { return fmt.Errorf("hash password: %w", err) } _, err = s.db.Exec( `UPDATE users SET password_hash = ?, must_reset_password = 0, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?`, hash, userID, ) return err } // UpdateProfile updates a user's profile fields. func (s *UserStore) UpdateProfile(userID int64, displayName, bio, profileVisibility, defaultFavePrivacy string) error { _, err := s.db.Exec( `UPDATE users SET display_name = ?, bio = ?, profile_visibility = ?, default_fave_privacy = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?`, displayName, bio, profileVisibility, defaultFavePrivacy, userID, ) return err } // SetMustResetPassword sets or clears the must_reset_password flag. func (s *UserStore) SetMustResetPassword(userID int64, must bool) error { val := 0 if must { val = 1 } _, err := s.db.Exec( `UPDATE users SET must_reset_password = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?`, val, userID, ) return err } // UpdateAvatar updates a user's avatar path. func (s *UserStore) UpdateAvatar(userID int64, avatarPath string) error { _, err := s.db.Exec( `UPDATE users SET avatar_path = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?`, avatarPath, userID, ) return err } // SetDisabled enables or disables a user account. func (s *UserStore) SetDisabled(userID int64, disabled bool) error { val := 0 if disabled { val = 1 } _, err := s.db.Exec( `UPDATE users SET disabled = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?`, val, userID, ) return err } // ListAll returns all users, ordered by username. func (s *UserStore) ListAll() ([]*model.User, error) { rows, err := s.db.Query( `SELECT id, username, display_name, bio, avatar_path, password_hash, role, profile_visibility, default_fave_privacy, must_reset_password, disabled, created_at, updated_at FROM users ORDER BY username COLLATE NOCASE`, ) if err != nil { return nil, err } defer rows.Close() var users []*model.User for rows.Next() { u, err := scanUserFrom(rows) if err != nil { return nil, err } users = append(users, u) } return users, rows.Err() } // Count returns the total number of users. func (s *UserStore) Count() (int, error) { var n int err := s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&n) return n, err } // EnsureAdmin creates the initial admin user if no users exist yet. // This is called on startup with the configured admin credentials. func (s *UserStore) EnsureAdmin(username, password string) error { if username == "" || password == "" { // No admin credentials configured — only skip if users already exist. count, err := s.Count() if err != nil { return err } if count == 0 { slog.Warn("no admin credentials configured and no users exist — set FAVORITTER_ADMIN_USERNAME and FAVORITTER_ADMIN_PASSWORD") } return nil } // Check if this admin already exists. _, err := s.GetByUsername(username) if err == nil { return nil // Already exists. } if !errors.Is(err, ErrUserNotFound) { return err } _, err = s.Create(username, password, "admin") if err != nil { return fmt.Errorf("create admin user: %w", err) } slog.Info("created initial admin user", "username", username) return nil } // scanner is implemented by both *sql.Row and *sql.Rows. type scanner interface { Scan(dest ...any) error } func scanUserFrom(s scanner) (*model.User, error) { u := &model.User{} var createdAt, updatedAt string err := s.Scan( &u.ID, &u.Username, &u.DisplayName, &u.Bio, &u.AvatarPath, &u.PasswordHash, &u.Role, &u.ProfileVisibility, &u.DefaultFavePrivacy, &u.MustResetPassword, &u.Disabled, &createdAt, &updatedAt, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrUserNotFound } if err != nil { return nil, fmt.Errorf("scan user: %w", err) } u.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) u.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) return u, nil } // Password hashing with Argon2id. // Format: $argon2id$v=19$m=65536,t=3,p=2$$ func hashPassword(password string) (string, error) { salt := make([]byte, Argon2SaltLength) if _, err := rand.Read(salt); err != nil { return "", fmt.Errorf("generate salt: %w", err) } hash := argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength) return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, Argon2Memory, Argon2Time, Argon2Parallelism, base64.RawStdEncoding.EncodeToString(salt), base64.RawStdEncoding.EncodeToString(hash), ), nil } func verifyPassword(password, encodedHash string) bool { parts := strings.Split(encodedHash, "$") if len(parts) != 6 || parts[1] != "argon2id" { return false } var memory uint32 var iterations uint32 var parallelism uint8 _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism) if err != nil { return false } salt, err := base64.RawStdEncoding.DecodeString(parts[4]) if err != nil { return false } expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) if err != nil { return false } hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash))) return subtle.ConstantTimeCompare(hash, expectedHash) == 1 } // dummyHash performs a hash operation to prevent timing-based username enumeration. func dummyHash(password string) { salt := make([]byte, Argon2SaltLength) argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength) }