// SPDX-License-Identifier: AGPL-3.0-or-later package store import ( "context" "crypto/rand" "database/sql" "encoding/hex" "errors" "fmt" "log/slog" "time" "kode.naiv.no/olemd/favoritter/internal/model" ) var ErrSessionNotFound = errors.New("session not found") type SessionStore struct { db *sql.DB lifetime time.Duration } func NewSessionStore(db *sql.DB) *SessionStore { return &SessionStore{db: db, lifetime: 720 * time.Hour} // default 30 days } // SetLifetime configures the session lifetime. func (s *SessionStore) SetLifetime(d time.Duration) { s.lifetime = d } // Create generates a new session token for the given user. func (s *SessionStore) Create(userID int64) (string, error) { tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { return "", fmt.Errorf("generate session token: %w", err) } token := hex.EncodeToString(tokenBytes) expiresAt := time.Now().UTC().Add(s.lifetime) _, err := s.db.Exec( `INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`, token, userID, expiresAt.Format(time.RFC3339), ) if err != nil { return "", fmt.Errorf("insert session: %w", err) } return token, nil } // Validate checks if a session token is valid and not expired. // Returns the session if valid. func (s *SessionStore) Validate(token string) (*model.Session, error) { var session model.Session var expiresAt, createdAt string err := s.db.QueryRow( `SELECT token, user_id, expires_at, created_at FROM sessions WHERE token = ?`, token, ).Scan(&session.Token, &session.UserID, &expiresAt, &createdAt) if errors.Is(err, sql.ErrNoRows) { return nil, ErrSessionNotFound } if err != nil { return nil, fmt.Errorf("query session: %w", err) } session.ExpiresAt, _ = time.Parse(time.RFC3339, expiresAt) session.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) if time.Now().UTC().After(session.ExpiresAt) { // Session has expired — delete it. s.Delete(token) return nil, ErrSessionNotFound } return &session, nil } // Delete removes a session by its token. func (s *SessionStore) Delete(token string) error { _, err := s.db.Exec("DELETE FROM sessions WHERE token = ?", token) return err } // DeleteAllForUser removes all sessions for a given user (e.g., on password change). func (s *SessionStore) DeleteAllForUser(userID int64) error { _, err := s.db.Exec("DELETE FROM sessions WHERE user_id = ?", userID) return err } // CleanupLoop periodically removes expired sessions. It runs until the // context is canceled. func (s *SessionStore) CleanupLoop(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: result, err := s.db.Exec( `DELETE FROM sessions WHERE expires_at < strftime('%Y-%m-%dT%H:%M:%SZ', 'now')`, ) if err != nil { slog.Error("session cleanup failed", "error", err) continue } n, _ := result.RowsAffected() if n > 0 { slog.Info("cleaned up expired sessions", "count", n) } } } }