221 lines
6.4 KiB
Go
221 lines
6.4 KiB
Go
|
|
package session
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"log/slog"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Default timings used by StartReaper whenever the matching ReaperConfig
// field is left unset.
const (
	// IdleTimeout is the default amount of inactivity after which a
	// session is reaped.
	IdleTimeout = 15 * time.Minute

	// ReapInterval is the default period between sweeps of the session
	// map by the idle reaper.
	ReapInterval = 30 * time.Second

	// ForgejoRefreshLeadTime is the default margin before Forgejo-token
	// expiry at which the rotator proactively swaps the upstream token.
	// Five minutes leaves enough slack for tokens granted with sub-hour
	// TTLs, yet is short enough that long-lived tokens are not refreshed
	// excessively.
	ForgejoRefreshLeadTime = 5 * time.Minute

	// RotateInterval is the default period between rotator scans for
	// sessions whose Forgejo tokens need refreshing.
	RotateInterval = time.Minute
)
|
||
|
|
|
||
|
|
// ReaperConfig bundles the inputs to StartReaper. All durations have
|
||
|
|
// sensible defaults if zero.
|
||
|
|
type ReaperConfig struct {
|
||
|
|
IdleTimeout time.Duration
|
||
|
|
ReapInterval time.Duration
|
||
|
|
RotateInterval time.Duration
|
||
|
|
RefreshLead time.Duration
|
||
|
|
|
||
|
|
// RefreshForgejo is called for each session whose upstream token is
|
||
|
|
// approaching expiry. The implementation refreshes via the Forgejo
|
||
|
|
// OAuth client, persists the new token in the access_tokens row, and
|
||
|
|
// returns the new token+expiry so the rotator can hand them to a
|
||
|
|
// freshly-spawned child. nil disables rotation.
|
||
|
|
RefreshForgejo func(ctx context.Context, sess *oauth.Session) (newAccess, newRefresh string, expiresAt time.Time, err error)
|
||
|
|
|
||
|
|
// Respawn is called when a session's upstream token has been refreshed.
|
||
|
|
// The implementation spawns a new Backend with the updated token and
|
||
|
|
// returns it; the reaper swaps it in atomically.
|
||
|
|
Respawn SpawnFunc
|
||
|
|
}
|
||
|
|
|
||
|
|
// StartReaper kicks off the idle-eviction and Forgejo-token-rotation
|
||
|
|
// goroutines. Returns a stop function the caller invokes at shutdown.
|
||
|
|
func (r *Registry) StartReaper(cfg ReaperConfig) (stop func()) {
|
||
|
|
idle := nonZero(cfg.IdleTimeout, IdleTimeout)
|
||
|
|
tick := nonZero(cfg.ReapInterval, ReapInterval)
|
||
|
|
rotateTick := nonZero(cfg.RotateInterval, RotateInterval)
|
||
|
|
lead := nonZero(cfg.RefreshLead, ForgejoRefreshLeadTime)
|
||
|
|
|
||
|
|
stopCh := make(chan struct{})
|
||
|
|
var once sync.Once
|
||
|
|
|
||
|
|
go r.reapLoop(stopCh, tick, idle)
|
||
|
|
if cfg.RefreshForgejo != nil && cfg.Respawn != nil {
|
||
|
|
go r.rotateLoop(stopCh, rotateTick, lead, cfg.RefreshForgejo, cfg.Respawn)
|
||
|
|
}
|
||
|
|
|
||
|
|
return func() { once.Do(func() { close(stopCh) }) }
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) reapLoop(stop <-chan struct{}, interval, idle time.Duration) {
|
||
|
|
t := time.NewTicker(interval)
|
||
|
|
defer t.Stop()
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-stop:
|
||
|
|
return
|
||
|
|
case <-t.C:
|
||
|
|
r.reapIdle(idle)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) reapIdle(idle time.Duration) {
|
||
|
|
cutoff := r.now().Add(-idle).UnixNano()
|
||
|
|
r.sessions.Range(func(k, v any) bool {
|
||
|
|
e := v.(*entry)
|
||
|
|
if e.lastActive.Load() < cutoff {
|
||
|
|
r.evict(e)
|
||
|
|
}
|
||
|
|
return true
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// evict removes the session from the registry and SIGTERMs its current
|
||
|
|
// backend. Used by both the idle reaper and the Forgejo-token rotator.
|
||
|
|
func (r *Registry) evict(e *entry) {
|
||
|
|
if _, ok := r.sessions.LoadAndDelete(e.sid); !ok {
|
||
|
|
return // already gone
|
||
|
|
}
|
||
|
|
r.count.Add(-1)
|
||
|
|
user := e.snapshotOAuth().ForgejoUsername
|
||
|
|
r.log.Info("session reaped",
|
||
|
|
slog.String("sid", e.sid),
|
||
|
|
slog.String("user", user))
|
||
|
|
|
||
|
|
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
if err := e.backend.Load().Stop(stopCtx); err != nil {
|
||
|
|
r.log.Warn("session stop on evict",
|
||
|
|
slog.String("sid", e.sid),
|
||
|
|
slog.String("err", err.Error()))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) rotateLoop(
|
||
|
|
stop <-chan struct{},
|
||
|
|
interval, lead time.Duration,
|
||
|
|
refresh func(context.Context, *oauth.Session) (string, string, time.Time, error),
|
||
|
|
respawn SpawnFunc,
|
||
|
|
) {
|
||
|
|
t := time.NewTicker(interval)
|
||
|
|
defer t.Stop()
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-stop:
|
||
|
|
return
|
||
|
|
case <-t.C:
|
||
|
|
r.rotateExpiring(lead, refresh, respawn)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) rotateExpiring(
|
||
|
|
lead time.Duration,
|
||
|
|
refresh func(context.Context, *oauth.Session) (string, string, time.Time, error),
|
||
|
|
respawn SpawnFunc,
|
||
|
|
) {
|
||
|
|
cutoff := r.now().Add(lead)
|
||
|
|
var due []*entry
|
||
|
|
r.sessions.Range(func(k, v any) bool {
|
||
|
|
e := v.(*entry)
|
||
|
|
if e.snapshotOAuth().ForgejoTokenExp.Before(cutoff) {
|
||
|
|
due = append(due, e)
|
||
|
|
}
|
||
|
|
return true
|
||
|
|
})
|
||
|
|
|
||
|
|
for _, e := range due {
|
||
|
|
sess := e.snapshotOAuth()
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
|
|
newAccess, newRefresh, expiresAt, err := refresh(ctx, sess)
|
||
|
|
cancel()
|
||
|
|
if err != nil {
|
||
|
|
r.log.Warn("forgejo refresh failed",
|
||
|
|
slog.String("sid", e.sid),
|
||
|
|
slog.String("user", sess.ForgejoUsername),
|
||
|
|
slog.String("err", err.Error()))
|
||
|
|
// On refresh failure, evict so the next /mcp request from
|
||
|
|
// this user produces a clean re-auth rather than continuing
|
||
|
|
// with a stale token.
|
||
|
|
r.evict(e)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
r.swapBackend(e, newAccess, newRefresh, expiresAt, respawn)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// swapBackend replaces e's backend with one spawned for an updated
|
||
|
|
// oauth.Session. The current child is SIGTERMed; the new one inherits
|
||
|
|
// the same sid so the client doesn't notice (other than re-issuing the
|
||
|
|
// MCP initialize handshake on its next request — see design.md §6).
|
||
|
|
func (r *Registry) swapBackend(
|
||
|
|
e *entry,
|
||
|
|
newAccess, newRefresh string,
|
||
|
|
expiresAt time.Time,
|
||
|
|
respawn SpawnFunc,
|
||
|
|
) {
|
||
|
|
current := e.snapshotOAuth()
|
||
|
|
updated := *current
|
||
|
|
updated.ForgejoToken = newAccess
|
||
|
|
updated.ForgejoRefresh = newRefresh
|
||
|
|
updated.ForgejoTokenExp = expiresAt
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
newBackend, err := respawn(ctx, &updated)
|
||
|
|
if err != nil {
|
||
|
|
r.log.Warn("respawn failed; evicting",
|
||
|
|
slog.String("sid", e.sid),
|
||
|
|
slog.String("err", err.Error()))
|
||
|
|
r.evict(e)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Atomic swap: from this point on, /mcp requests dispatch to the new
|
||
|
|
// backend. The old backend's Done watcher (started in spawnSession)
|
||
|
|
// will fire once we Stop it, but compares against e.backend.Load() —
|
||
|
|
// since that now points at newBackend, the watcher is a no-op and
|
||
|
|
// the session survives the rotation.
|
||
|
|
old := e.backend.Swap(newBackend)
|
||
|
|
e.mu.Lock()
|
||
|
|
e.oauthSess = &updated
|
||
|
|
e.mu.Unlock()
|
||
|
|
r.watchBackend(e.sid, newBackend)
|
||
|
|
|
||
|
|
go func() {
|
||
|
|
stopCtx, c := context.WithTimeout(context.Background(), 5*time.Second)
|
||
|
|
defer c()
|
||
|
|
_ = old.Stop(stopCtx)
|
||
|
|
}()
|
||
|
|
r.log.Info("session rotated",
|
||
|
|
slog.String("sid", e.sid),
|
||
|
|
slog.String("user", updated.ForgejoUsername))
|
||
|
|
}
|
||
|
|
|
||
|
|
// nonZero returns d when it is strictly positive and fallback otherwise,
// so both zero and negative durations select the fallback.
func nonZero(d, fallback time.Duration) time.Duration {
	if d <= 0 {
		return fallback
	}
	return d
}
|