package session import ( "context" "log/slog" "sync" "time" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth" ) // IdleTimeout is the default time-since-last-activity after which a // session is reaped. const IdleTimeout = 15 * time.Minute // ReapInterval is how often the reaper sweeps the session map. const ReapInterval = 30 * time.Second // ForgejoRefreshLeadTime is how far before Forgejo-token expiry the // rotator proactively swaps the upstream token. Five minutes is enough // slack for tokens granted with sub-hour TTLs while still being short // enough that we don't refresh excessively for long-lived ones. const ForgejoRefreshLeadTime = 5 * time.Minute // RotateInterval is how often the rotator scans for sessions whose // Forgejo tokens need refreshing. const RotateInterval = 1 * time.Minute // ReaperConfig bundles the inputs to StartReaper. All durations have // sensible defaults if zero. type ReaperConfig struct { IdleTimeout time.Duration ReapInterval time.Duration RotateInterval time.Duration RefreshLead time.Duration // RefreshForgejo is called for each session whose upstream token is // approaching expiry. The implementation refreshes via the Forgejo // OAuth client, persists the new token in the access_tokens row, and // returns the new token+expiry so the rotator can hand them to a // freshly-spawned child. nil disables rotation. RefreshForgejo func(ctx context.Context, sess *oauth.Session) (newAccess, newRefresh string, expiresAt time.Time, err error) // Respawn is called when a session's upstream token has been refreshed. // The implementation spawns a new Backend with the updated token and // returns it; the reaper swaps it in atomically. Respawn SpawnFunc } // StartReaper kicks off the idle-eviction and Forgejo-token-rotation // goroutines. Returns a stop function the caller invokes at shutdown. func (r *Registry) StartReaper(cfg ReaperConfig) (stop func()) { idle := nonZero(cfg.IdleTimeout, IdleTimeout) tick := nonZero(cfg.ReapInterval, ReapInterval) rotateTick := nonZero(cfg.RotateInterval, RotateInterval) lead := nonZero(cfg.RefreshLead, ForgejoRefreshLeadTime) stopCh := make(chan struct{}) var once sync.Once go r.reapLoop(stopCh, tick, idle) if cfg.RefreshForgejo != nil && cfg.Respawn != nil { go r.rotateLoop(stopCh, rotateTick, lead, cfg.RefreshForgejo, cfg.Respawn) } return func() { once.Do(func() { close(stopCh) }) } } func (r *Registry) reapLoop(stop <-chan struct{}, interval, idle time.Duration) { t := time.NewTicker(interval) defer t.Stop() for { select { case <-stop: return case <-t.C: r.reapIdle(idle) } } } func (r *Registry) reapIdle(idle time.Duration) { cutoff := r.now().Add(-idle).UnixNano() r.sessions.Range(func(k, v any) bool { e := v.(*entry) if e.lastActive.Load() < cutoff { r.evict(e) } return true }) } // evict removes the session from the registry and SIGTERMs its current // backend. Used by both the idle reaper and the Forgejo-token rotator. func (r *Registry) evict(e *entry) { if _, ok := r.sessions.LoadAndDelete(e.sid); !ok { return // already gone } r.count.Add(-1) user := e.snapshotOAuth().ForgejoUsername r.log.Info("session reaped", slog.String("sid", e.sid), slog.String("user", user)) stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := e.backend.Load().Stop(stopCtx); err != nil { r.log.Warn("session stop on evict", slog.String("sid", e.sid), slog.String("err", err.Error())) } } func (r *Registry) rotateLoop( stop <-chan struct{}, interval, lead time.Duration, refresh func(context.Context, *oauth.Session) (string, string, time.Time, error), respawn SpawnFunc, ) { t := time.NewTicker(interval) defer t.Stop() for { select { case <-stop: return case <-t.C: r.rotateExpiring(lead, refresh, respawn) } } } func (r *Registry) rotateExpiring( lead time.Duration, refresh func(context.Context, *oauth.Session) (string, string, time.Time, error), respawn SpawnFunc, ) { cutoff := r.now().Add(lead) var due []*entry r.sessions.Range(func(k, v any) bool { e := v.(*entry) if e.snapshotOAuth().ForgejoTokenExp.Before(cutoff) { due = append(due, e) } return true }) for _, e := range due { sess := e.snapshotOAuth() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) newAccess, newRefresh, expiresAt, err := refresh(ctx, sess) cancel() if err != nil { r.log.Warn("forgejo refresh failed", slog.String("sid", e.sid), slog.String("user", sess.ForgejoUsername), slog.String("err", err.Error())) // On refresh failure, evict so the next /mcp request from // this user produces a clean re-auth rather than continuing // with a stale token. r.evict(e) continue } r.swapBackend(e, newAccess, newRefresh, expiresAt, respawn) } } // swapBackend replaces e's backend with one spawned for an updated // oauth.Session. The current child is SIGTERMed; the new one inherits // the same sid so the client doesn't notice (other than re-issuing the // MCP initialize handshake on its next request — see design.md §6). func (r *Registry) swapBackend( e *entry, newAccess, newRefresh string, expiresAt time.Time, respawn SpawnFunc, ) { current := e.snapshotOAuth() updated := *current updated.ForgejoToken = newAccess updated.ForgejoRefresh = newRefresh updated.ForgejoTokenExp = expiresAt ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() newBackend, err := respawn(ctx, &updated) if err != nil { r.log.Warn("respawn failed; evicting", slog.String("sid", e.sid), slog.String("err", err.Error())) r.evict(e) return } // Atomic swap: from this point on, /mcp requests dispatch to the new // backend. The old backend's Done watcher (started in spawnSession) // will fire once we Stop it, but compares against e.backend.Load() — // since that now points at newBackend, the watcher is a no-op and // the session survives the rotation. old := e.backend.Swap(newBackend) e.mu.Lock() e.oauthSess = &updated e.mu.Unlock() r.watchBackend(e.sid, newBackend) go func() { stopCtx, c := context.WithTimeout(context.Background(), 5*time.Second) defer c() _ = old.Stop(stopCtx) }() r.log.Info("session rotated", slog.String("sid", e.sid), slog.String("user", updated.ForgejoUsername)) } func nonZero(d, fallback time.Duration) time.Duration { if d > 0 { return d } return fallback }