119 lines
3 KiB
Go
119 lines
3 KiB
Go
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
|
|
||
|
|
package store
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"database/sql"
|
||
|
|
"encoding/hex"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/favoritter/internal/model"
|
||
|
|
)
|
||
|
|
|
||
|
|
var (
	// ErrSessionNotFound is returned by SessionStore.Validate when no stored
	// session matches the given token, or when the matching session has
	// already passed its expiry time.
	ErrSessionNotFound = errors.New("session not found")
)
|
||
|
|
|
||
|
|
// SessionStore keeps user sessions in the sessions table of a SQL database
// and identifies each one by a random hex token.
type SessionStore struct {
	db       *sql.DB       // handle used for every session query
	lifetime time.Duration // validity period applied to newly created sessions
}

// NewSessionStore returns a SessionStore that uses db. Sessions created by
// the returned store are valid for 30 days unless SetLifetime is called.
func NewSessionStore(db *sql.DB) *SessionStore {
	const thirtyDays = 30 * 24 * time.Hour

	store := new(SessionStore)
	store.db = db
	store.lifetime = thirtyDays
	return store
}

// SetLifetime changes the validity period used for sessions created after
// this call; sessions already stored keep the expiry they were written with.
// It performs no locking, so call it during setup, before the store is used
// from multiple goroutines.
func (s *SessionStore) SetLifetime(d time.Duration) {
	s.lifetime = d
}
|
||
|
|
|
||
|
|
// Create generates a new session token for the given user.
|
||
|
|
func (s *SessionStore) Create(userID int64) (string, error) {
|
||
|
|
tokenBytes := make([]byte, 32)
|
||
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||
|
|
return "", fmt.Errorf("generate session token: %w", err)
|
||
|
|
}
|
||
|
|
token := hex.EncodeToString(tokenBytes)
|
||
|
|
|
||
|
|
expiresAt := time.Now().UTC().Add(s.lifetime)
|
||
|
|
_, err := s.db.Exec(
|
||
|
|
`INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`,
|
||
|
|
token, userID, expiresAt.Format(time.RFC3339),
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("insert session: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return token, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate checks if a session token is valid and not expired.
|
||
|
|
// Returns the session if valid.
|
||
|
|
func (s *SessionStore) Validate(token string) (*model.Session, error) {
|
||
|
|
var session model.Session
|
||
|
|
var expiresAt, createdAt string
|
||
|
|
err := s.db.QueryRow(
|
||
|
|
`SELECT token, user_id, expires_at, created_at
|
||
|
|
FROM sessions WHERE token = ?`, token,
|
||
|
|
).Scan(&session.Token, &session.UserID, &expiresAt, &createdAt)
|
||
|
|
|
||
|
|
if errors.Is(err, sql.ErrNoRows) {
|
||
|
|
return nil, ErrSessionNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("query session: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
session.ExpiresAt, _ = time.Parse(time.RFC3339, expiresAt)
|
||
|
|
session.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
|
|
|
||
|
|
if time.Now().UTC().After(session.ExpiresAt) {
|
||
|
|
// Session has expired — delete it.
|
||
|
|
s.Delete(token)
|
||
|
|
return nil, ErrSessionNotFound
|
||
|
|
}
|
||
|
|
|
||
|
|
return &session, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Delete removes a session by its token.
|
||
|
|
func (s *SessionStore) Delete(token string) error {
|
||
|
|
_, err := s.db.Exec("DELETE FROM sessions WHERE token = ?", token)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// DeleteAllForUser removes all sessions for a given user (e.g., on password change).
|
||
|
|
func (s *SessionStore) DeleteAllForUser(userID int64) error {
|
||
|
|
_, err := s.db.Exec("DELETE FROM sessions WHERE user_id = ?", userID)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// CleanupLoop periodically removes expired sessions. It runs until the
|
||
|
|
// context is canceled.
|
||
|
|
func (s *SessionStore) CleanupLoop(ctx context.Context, interval time.Duration) {
|
||
|
|
ticker := time.NewTicker(interval)
|
||
|
|
defer ticker.Stop()
|
||
|
|
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
return
|
||
|
|
case <-ticker.C:
|
||
|
|
result, err := s.db.Exec(
|
||
|
|
`DELETE FROM sessions WHERE expires_at < strftime('%Y-%m-%dT%H:%M:%SZ', 'now')`,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("session cleanup failed", "error", err)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
n, _ := result.RowsAffected()
|
||
|
|
if n > 0 {
|
||
|
|
slog.Info("cleaned up expired sessions", "count", n)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|