326 lines
8.8 KiB
Go
326 lines
8.8 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package store
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/rand"
|
||
|
|
"crypto/subtle"
|
||
|
|
"database/sql"
|
||
|
|
"encoding/base64"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/argon2"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/favoritter/internal/model"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Argon2id cost parameters applied when hashPassword creates a new hash; dummyHash
// uses them as well. verifyPassword does not read these: it parses the
// parameters embedded in each stored hash, so changing a value here only
// affects hashes created afterwards.
//
// NOTE(review): these defaults are described as matching OWASP recommendations —
// confirm against the current OWASP Password Storage Cheat Sheet.
var (
	// Argon2Time is the number of passes over memory (encoded as t=).
	Argon2Time uint32 = 3

	// Argon2Memory is the memory cost in KiB (encoded as m=); 65536 KiB is 64 MiB.
	Argon2Memory uint32 = 65536

	// Argon2Parallelism is the degree of parallelism (encoded as p=).
	Argon2Parallelism uint8 = 2

	// Argon2KeyLength is the length, in bytes, of the derived hash.
	Argon2KeyLength uint32 = 32

	// Argon2SaltLength is the length, in bytes, of the random per-password salt.
	Argon2SaltLength int = 16
)
|
||
|
|
|
||
|
|
// Sentinel errors returned by UserStore methods. Compare them with errors.Is.
var (
	// ErrUserNotFound means no users row matched the requested ID or username.
	ErrUserNotFound = errors.New("user not found")

	// ErrUserExists is returned by Create and CreateWithReset when the insert
	// violates a UNIQUE constraint.
	ErrUserExists = errors.New("username already taken")

	// ErrUserDisabled is returned by Authenticate for an account whose
	// disabled flag is set.
	ErrUserDisabled = errors.New("user account is disabled")

	// ErrInvalidCredentials is returned by Authenticate for an unknown
	// username or a wrong password; the two cases are deliberately not
	// distinguished.
	ErrInvalidCredentials = errors.New("invalid username or password")
)
|
||
|
|
|
||
|
|
type UserStore struct {
|
||
|
|
db *sql.DB
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewUserStore(db *sql.DB) *UserStore {
|
||
|
|
return &UserStore{db: db}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create creates a new user with the given username and plaintext password.
|
||
|
|
func (s *UserStore) Create(username, password, role string) (*model.User, error) {
|
||
|
|
hash, err := hashPassword(password)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("hash password: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
result, err := s.db.Exec(
|
||
|
|
`INSERT INTO users (username, password_hash, role) VALUES (?, ?, ?)`,
|
||
|
|
username, hash, role,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||
|
|
return nil, ErrUserExists
|
||
|
|
}
|
||
|
|
return nil, fmt.Errorf("insert user: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
id, _ := result.LastInsertId()
|
||
|
|
return s.GetByID(id)
|
||
|
|
}
|
||
|
|
|
||
|
|
// CreateWithReset creates a new user that must reset their password on first login.
|
||
|
|
func (s *UserStore) CreateWithReset(username, tempPassword, role string) (*model.User, error) {
|
||
|
|
hash, err := hashPassword(tempPassword)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("hash password: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
result, err := s.db.Exec(
|
||
|
|
`INSERT INTO users (username, password_hash, role, must_reset_password) VALUES (?, ?, ?, 1)`,
|
||
|
|
username, hash, role,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||
|
|
return nil, ErrUserExists
|
||
|
|
}
|
||
|
|
return nil, fmt.Errorf("insert user: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
id, _ := result.LastInsertId()
|
||
|
|
return s.GetByID(id)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Authenticate verifies credentials and returns the user if valid.
|
||
|
|
func (s *UserStore) Authenticate(username, password string) (*model.User, error) {
|
||
|
|
user, err := s.GetByUsername(username)
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, ErrUserNotFound) {
|
||
|
|
// Still do a dummy hash comparison to prevent timing attacks.
|
||
|
|
dummyHash(password)
|
||
|
|
return nil, ErrInvalidCredentials
|
||
|
|
}
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if user.Disabled {
|
||
|
|
return nil, ErrUserDisabled
|
||
|
|
}
|
||
|
|
|
||
|
|
if !verifyPassword(password, user.PasswordHash) {
|
||
|
|
return nil, ErrInvalidCredentials
|
||
|
|
}
|
||
|
|
|
||
|
|
return user, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByID retrieves a user by their ID.
|
||
|
|
func (s *UserStore) GetByID(id int64) (*model.User, error) {
|
||
|
|
return scanUserFrom(s.db.QueryRow(
|
||
|
|
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||
|
|
role, profile_visibility, default_fave_privacy,
|
||
|
|
must_reset_password, disabled, created_at, updated_at
|
||
|
|
FROM users WHERE id = ?`, id,
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByUsername retrieves a user by their username (case-insensitive).
|
||
|
|
func (s *UserStore) GetByUsername(username string) (*model.User, error) {
|
||
|
|
return scanUserFrom(s.db.QueryRow(
|
||
|
|
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||
|
|
role, profile_visibility, default_fave_privacy,
|
||
|
|
must_reset_password, disabled, created_at, updated_at
|
||
|
|
FROM users WHERE username = ?`, username,
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
// UpdatePassword changes a user's password and clears the must_reset_password flag.
|
||
|
|
func (s *UserStore) UpdatePassword(userID int64, newPassword string) error {
|
||
|
|
hash, err := hashPassword(newPassword)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("hash password: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = s.db.Exec(
|
||
|
|
`UPDATE users SET password_hash = ?, must_reset_password = 0,
|
||
|
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||
|
|
WHERE id = ?`,
|
||
|
|
hash, userID,
|
||
|
|
)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// UpdateProfile updates a user's profile fields.
|
||
|
|
func (s *UserStore) UpdateProfile(userID int64, displayName, bio, profileVisibility, defaultFavePrivacy string) error {
|
||
|
|
_, err := s.db.Exec(
|
||
|
|
`UPDATE users SET display_name = ?, bio = ?, profile_visibility = ?,
|
||
|
|
default_fave_privacy = ?,
|
||
|
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||
|
|
WHERE id = ?`,
|
||
|
|
displayName, bio, profileVisibility, defaultFavePrivacy, userID,
|
||
|
|
)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// UpdateAvatar updates a user's avatar path.
|
||
|
|
func (s *UserStore) UpdateAvatar(userID int64, avatarPath string) error {
|
||
|
|
_, err := s.db.Exec(
|
||
|
|
`UPDATE users SET avatar_path = ?,
|
||
|
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||
|
|
WHERE id = ?`,
|
||
|
|
avatarPath, userID,
|
||
|
|
)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// SetDisabled enables or disables a user account.
|
||
|
|
func (s *UserStore) SetDisabled(userID int64, disabled bool) error {
|
||
|
|
val := 0
|
||
|
|
if disabled {
|
||
|
|
val = 1
|
||
|
|
}
|
||
|
|
_, err := s.db.Exec(
|
||
|
|
`UPDATE users SET disabled = ?,
|
||
|
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
|
||
|
|
WHERE id = ?`,
|
||
|
|
val, userID,
|
||
|
|
)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// ListAll returns all users, ordered by username.
|
||
|
|
func (s *UserStore) ListAll() ([]*model.User, error) {
|
||
|
|
rows, err := s.db.Query(
|
||
|
|
`SELECT id, username, display_name, bio, avatar_path, password_hash,
|
||
|
|
role, profile_visibility, default_fave_privacy,
|
||
|
|
must_reset_password, disabled, created_at, updated_at
|
||
|
|
FROM users ORDER BY username COLLATE NOCASE`,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var users []*model.User
|
||
|
|
for rows.Next() {
|
||
|
|
u, err := scanUserFrom(rows)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
users = append(users, u)
|
||
|
|
}
|
||
|
|
return users, rows.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// Count returns the total number of users.
|
||
|
|
func (s *UserStore) Count() (int, error) {
|
||
|
|
var n int
|
||
|
|
err := s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&n)
|
||
|
|
return n, err
|
||
|
|
}
|
||
|
|
|
||
|
|
// EnsureAdmin creates the initial admin user if no users exist yet.
|
||
|
|
// This is called on startup with the configured admin credentials.
|
||
|
|
func (s *UserStore) EnsureAdmin(username, password string) error {
|
||
|
|
if username == "" || password == "" {
|
||
|
|
// No admin credentials configured — only skip if users already exist.
|
||
|
|
count, err := s.Count()
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if count == 0 {
|
||
|
|
slog.Warn("no admin credentials configured and no users exist — set FAVORITTER_ADMIN_USERNAME and FAVORITTER_ADMIN_PASSWORD")
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check if this admin already exists.
|
||
|
|
_, err := s.GetByUsername(username)
|
||
|
|
if err == nil {
|
||
|
|
return nil // Already exists.
|
||
|
|
}
|
||
|
|
if !errors.Is(err, ErrUserNotFound) {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = s.Create(username, password, "admin")
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("create admin user: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
slog.Info("created initial admin user", "username", username)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// scanner abstracts the single Scan method shared by *sql.Row and *sql.Rows.
// It lets scanUserFrom decode a user from either a QueryRow result or one step
// of a Query iteration.
type scanner interface {
	// Scan copies the current row's columns into dest, in column order.
	Scan(dest ...any) error
}
|
||
|
|
|
||
|
|
func scanUserFrom(s scanner) (*model.User, error) {
|
||
|
|
u := &model.User{}
|
||
|
|
var createdAt, updatedAt string
|
||
|
|
err := s.Scan(
|
||
|
|
&u.ID, &u.Username, &u.DisplayName, &u.Bio, &u.AvatarPath,
|
||
|
|
&u.PasswordHash, &u.Role, &u.ProfileVisibility, &u.DefaultFavePrivacy,
|
||
|
|
&u.MustResetPassword, &u.Disabled, &createdAt, &updatedAt,
|
||
|
|
)
|
||
|
|
if errors.Is(err, sql.ErrNoRows) {
|
||
|
|
return nil, ErrUserNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
||
|
|
}
|
||
|
|
u.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
|
|
u.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||
|
|
return u, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Password hashing with Argon2id.
|
||
|
|
// Format: $argon2id$v=19$m=65536,t=3,p=2$<salt>$<hash>
|
||
|
|
|
||
|
|
func hashPassword(password string) (string, error) {
|
||
|
|
salt := make([]byte, Argon2SaltLength)
|
||
|
|
if _, err := rand.Read(salt); err != nil {
|
||
|
|
return "", fmt.Errorf("generate salt: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
hash := argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength)
|
||
|
|
|
||
|
|
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||
|
|
argon2.Version,
|
||
|
|
Argon2Memory, Argon2Time, Argon2Parallelism,
|
||
|
|
base64.RawStdEncoding.EncodeToString(salt),
|
||
|
|
base64.RawStdEncoding.EncodeToString(hash),
|
||
|
|
), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func verifyPassword(password, encodedHash string) bool {
|
||
|
|
parts := strings.Split(encodedHash, "$")
|
||
|
|
if len(parts) != 6 || parts[1] != "argon2id" {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
var memory uint32
|
||
|
|
var iterations uint32
|
||
|
|
var parallelism uint8
|
||
|
|
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism)
|
||
|
|
if err != nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||
|
|
if err != nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||
|
|
if err != nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash)))
|
||
|
|
|
||
|
|
return subtle.ConstantTimeCompare(hash, expectedHash) == 1
|
||
|
|
}
|
||
|
|
|
||
|
|
// dummyHash performs a hash operation to prevent timing-based username enumeration.
|
||
|
|
func dummyHash(password string) {
|
||
|
|
salt := make([]byte, Argon2SaltLength)
|
||
|
|
argon2.IDKey([]byte(password), salt, Argon2Time, Argon2Memory, Argon2Parallelism, Argon2KeyLength)
|
||
|
|
}
|