From 933e7bd369db33a3701e8171eff62e5944dd80d5 Mon Sep 17 00:00:00 2001
From: Ole-Morten Duesund
Date: Mon, 27 Apr 2026 17:32:36 +0200
Subject: [PATCH] feat(session): idle reaper + Forgejo token rotation (forgejo-mcp-broker-q4x)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds StartReaper to internal/session — two background goroutines that
keep the session map healthy under steady load.

Idle reaper:
- Sweeps every ReapInterval (default 30s) for sessions whose
  LastActive is older than IdleTimeout (default 15m).
- Evicts via SIGTERM through the Backend.Stop hook.

Token rotator:
- Sweeps every RotateInterval (default 1m) for sessions whose Forgejo
  token is within RefreshLead (default 5m) of expiry.
- Calls the operator-supplied RefreshForgejo to obtain new
  access+refresh tokens, then Respawn to mint a new Backend with the
  updated token in env.
- Atomically swaps e.backend (now an atomic.Pointer[Backend]); the sid
  is preserved so the client just re-issues an MCP `initialize` on its
  next request rather than re-authenticating.
- On refresh failure, evicts so the next /mcp produces a clean re-auth
  instead of carrying a stale token.

Two race fixes uncovered by -race during this work:
- The Done-watcher started in spawnSession captured the original
  backend pointer; after rotation it still saw Done close (because the
  old backend was Stopped) and would yank the entire entry. Fixed by
  comparing watched-backend == e.backend.Load() before evicting.
- The fakeSpawner test helper let tests read the backends slice
  without the lock the spawn callback held. Replaced with a
  spawnerControl type whose count/at/snapshot methods all lock.

Tests cover idle eviction, recently-active sessions surviving sweeps,
successful rotation+respawn (sid preserved), refresh failure →
eviction, and Stop idempotency.

Closes forgejo-mcp-broker-q4x. Phase 5 complete.
Co-Authored-By: Claude Opus 4.7 (1M context) --- .beads/issues.jsonl | 2 +- internal/session/reaper.go | 221 +++++++++++++++++++++++++++++++ internal/session/reaper_test.go | 186 ++++++++++++++++++++++++++ internal/session/session.go | 48 +++++-- internal/session/session_test.go | 80 +++++++---- 5 files changed, 501 insertions(+), 36 deletions(-) create mode 100644 internal/session/reaper.go create mode 100644 internal/session/reaper_test.go diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl index affee4f..651304e 100644 --- a/.beads/issues.jsonl +++ b/.beads/issues.jsonl @@ -1,4 +1,4 @@ -{"id":"forgejo-mcp-broker-q4x","title":"Phase 5c: idle reaper + Forgejo token rotation + child respawn","description":"Reaper (30s tick) applies idle timeout. Rotation (1-min tick) refreshes Forgejo tokens expiring \u003c2min, SIGTERMs child, respawns on next request (design.md §6). Token revocation tears down sessions.","acceptance_criteria":"Clock-injected tests: idle kill, rotation triggers respawn, revocation tears down sessions. Smoke test: 20 concurrent sessions for 10min with mid-test rotations.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:18Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:18Z","dependencies":[{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-t81","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":1,"comment_count":0} +{"id":"forgejo-mcp-broker-q4x","title":"Phase 5c: idle reaper + Forgejo token rotation + child respawn","description":"Reaper (30s tick) applies idle timeout. 
Rotation (1-min tick) refreshes Forgejo tokens expiring \u003c2min, SIGTERMs child, respawns on next request (design.md §6). Token revocation tears down sessions.","acceptance_criteria":"Clock-injected tests: idle kill, rotation triggers respawn, revocation tears down sessions. Smoke test: 20 concurrent sessions for 10min with mid-test rotations.","status":"in_progress","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:18Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T15:27:46Z","started_at":"2026-04-27T15:27:46Z","dependencies":[{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-t81","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-ytw","title":"Phase 5b: bearer-token middleware on /mcp","description":"Middleware reads Authorization: Bearer \u003cmcp_token\u003e, resolves via store, attaches Forgejo access token to request context. 401 for missing/expired/revoked.","acceptance_criteria":"Table-driven tests: missing header, malformed, unknown token, expired, revoked, valid. Valid-token path puts Forgejo token on ctx via typed key.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:18Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T15:10:28Z","started_at":"2026-04-27T15:08:52Z","closed_at":"2026-04-27T15:10:28Z","close_reason":"Bearer middleware shipped: RequireBearer wraps protected handlers, looks up access_tokens by hash, rejects expired/revoked/unknown with RFC 6750 WWW-Authenticate. 
Session attached to ctx for downstream MCP endpoint use.","dependencies":[{"issue_id":"forgejo-mcp-broker-ytw","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:30Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-t81","title":"Phase 5a: session registry + spawn-on-initialize","description":"Map Mcp-Session-Id -\u003e supervisor.Child + user metadata. On first initialize for unknown sid, spawn forgejo-mcp with user's Forgejo token in env, bind to bridge. LastActive bumped per request.","acceptance_criteria":"Tests with fake supervisor + fake bridge cover: spawn-on-initialize, reuse for subsequent messages, unknown-sid returns 410, max-sessions cap enforced.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:17Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T15:24:32Z","started_at":"2026-04-27T15:11:43Z","closed_at":"2026-04-27T15:24:32Z","close_reason":"Session registry shipped: Mcp-Session-Id minting on initialize, sid reuse for follow-ups, 410 for unknown, 403 for sid-hijack, max-sessions cap with Retry-After, Done-channel reaping, graceful Stop. Decoupled from supervisor/bridge via SpawnFunc. 
83.3% coverage.","dependencies":[{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-am1","type":"blocks","created_at":"2026-04-24T17:45:29Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:30Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-zuq","type":"blocks","created_at":"2026-04-24T17:45:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":3,"dependent_count":2,"comment_count":0} {"id":"forgejo-mcp-broker-xot","title":"Phase 4b: bridge integration test against real forgejo-mcp","description":"Drive the bridge with initialize -\u003e tools/list -\u003e tools/call get_forgejo_mcp_server_version against a real forgejo-mcp subprocess. Validates the opaque-pipe assumption.","acceptance_criteria":"Full handshake, tools/list returns expected set, tools/call returns a version string. Tagged as integration test if runtime exceeds 2s.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:16Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T14:28:39Z","started_at":"2026-04-27T14:10:04Z","closed_at":"2026-04-27T14:28:39Z","close_reason":"Bridge integration test passes against real forgejo-mcp 2.2.0: MCP handshake (initialize → notifications/initialized → tools/list → tools/call) round-trips through bridge cleanly. Fake Forgejo covers /api/v1/version and /api/v1/user probes. 
Phase 4 complete.","dependencies":[{"issue_id":"forgejo-mcp-broker-xot","depends_on_id":"forgejo-mcp-broker-am1","type":"blocks","created_at":"2026-04-24T17:45:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":0,"comment_count":0} diff --git a/internal/session/reaper.go b/internal/session/reaper.go new file mode 100644 index 0000000..90423db --- /dev/null +++ b/internal/session/reaper.go @@ -0,0 +1,221 @@ +package session + +import ( + "context" + "log/slog" + "sync" + "time" + + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth" +) + +// IdleTimeout is the default time-since-last-activity after which a +// session is reaped. +const IdleTimeout = 15 * time.Minute + +// ReapInterval is how often the reaper sweeps the session map. +const ReapInterval = 30 * time.Second + +// ForgejoRefreshLeadTime is how far before Forgejo-token expiry the +// rotator proactively swaps the upstream token. Five minutes is enough +// slack for tokens granted with sub-hour TTLs while still being short +// enough that we don't refresh excessively for long-lived ones. +const ForgejoRefreshLeadTime = 5 * time.Minute + +// RotateInterval is how often the rotator scans for sessions whose +// Forgejo tokens need refreshing. +const RotateInterval = 1 * time.Minute + +// ReaperConfig bundles the inputs to StartReaper. All durations have +// sensible defaults if zero. +type ReaperConfig struct { + IdleTimeout time.Duration + ReapInterval time.Duration + RotateInterval time.Duration + RefreshLead time.Duration + + // RefreshForgejo is called for each session whose upstream token is + // approaching expiry. The implementation refreshes via the Forgejo + // OAuth client, persists the new token in the access_tokens row, and + // returns the new token+expiry so the rotator can hand them to a + // freshly-spawned child. nil disables rotation. 
+ RefreshForgejo func(ctx context.Context, sess *oauth.Session) (newAccess, newRefresh string, expiresAt time.Time, err error) + + // Respawn is called when a session's upstream token has been refreshed. + // The implementation spawns a new Backend with the updated token and + // returns it; the reaper swaps it in atomically. + Respawn SpawnFunc +} + +// StartReaper kicks off the idle-eviction and Forgejo-token-rotation +// goroutines. Returns a stop function the caller invokes at shutdown. +func (r *Registry) StartReaper(cfg ReaperConfig) (stop func()) { + idle := nonZero(cfg.IdleTimeout, IdleTimeout) + tick := nonZero(cfg.ReapInterval, ReapInterval) + rotateTick := nonZero(cfg.RotateInterval, RotateInterval) + lead := nonZero(cfg.RefreshLead, ForgejoRefreshLeadTime) + + stopCh := make(chan struct{}) + var once sync.Once + + go r.reapLoop(stopCh, tick, idle) + if cfg.RefreshForgejo != nil && cfg.Respawn != nil { + go r.rotateLoop(stopCh, rotateTick, lead, cfg.RefreshForgejo, cfg.Respawn) + } + + return func() { once.Do(func() { close(stopCh) }) } +} + +func (r *Registry) reapLoop(stop <-chan struct{}, interval, idle time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-stop: + return + case <-t.C: + r.reapIdle(idle) + } + } +} + +func (r *Registry) reapIdle(idle time.Duration) { + cutoff := r.now().Add(-idle).UnixNano() + r.sessions.Range(func(k, v any) bool { + e := v.(*entry) + if e.lastActive.Load() < cutoff { + r.evict(e) + } + return true + }) +} + +// evict removes the session from the registry and SIGTERMs its current +// backend. Used by both the idle reaper and the Forgejo-token rotator. 
+func (r *Registry) evict(e *entry) { + if _, ok := r.sessions.LoadAndDelete(e.sid); !ok { + return // already gone + } + r.count.Add(-1) + user := e.snapshotOAuth().ForgejoUsername + r.log.Info("session reaped", + slog.String("sid", e.sid), + slog.String("user", user)) + + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.backend.Load().Stop(stopCtx); err != nil { + r.log.Warn("session stop on evict", + slog.String("sid", e.sid), + slog.String("err", err.Error())) + } +} + +func (r *Registry) rotateLoop( + stop <-chan struct{}, + interval, lead time.Duration, + refresh func(context.Context, *oauth.Session) (string, string, time.Time, error), + respawn SpawnFunc, +) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-stop: + return + case <-t.C: + r.rotateExpiring(lead, refresh, respawn) + } + } +} + +func (r *Registry) rotateExpiring( + lead time.Duration, + refresh func(context.Context, *oauth.Session) (string, string, time.Time, error), + respawn SpawnFunc, +) { + cutoff := r.now().Add(lead) + var due []*entry + r.sessions.Range(func(k, v any) bool { + e := v.(*entry) + if e.snapshotOAuth().ForgejoTokenExp.Before(cutoff) { + due = append(due, e) + } + return true + }) + + for _, e := range due { + sess := e.snapshotOAuth() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + newAccess, newRefresh, expiresAt, err := refresh(ctx, sess) + cancel() + if err != nil { + r.log.Warn("forgejo refresh failed", + slog.String("sid", e.sid), + slog.String("user", sess.ForgejoUsername), + slog.String("err", err.Error())) + // On refresh failure, evict so the next /mcp request from + // this user produces a clean re-auth rather than continuing + // with a stale token. + r.evict(e) + continue + } + r.swapBackend(e, newAccess, newRefresh, expiresAt, respawn) + } +} + +// swapBackend replaces e's backend with one spawned for an updated +// oauth.Session. 
The current child is SIGTERMed; the new one inherits +// the same sid so the client doesn't notice (other than re-issuing the +// MCP initialize handshake on its next request — see design.md §6). +func (r *Registry) swapBackend( + e *entry, + newAccess, newRefresh string, + expiresAt time.Time, + respawn SpawnFunc, +) { + current := e.snapshotOAuth() + updated := *current + updated.ForgejoToken = newAccess + updated.ForgejoRefresh = newRefresh + updated.ForgejoTokenExp = expiresAt + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + newBackend, err := respawn(ctx, &updated) + if err != nil { + r.log.Warn("respawn failed; evicting", + slog.String("sid", e.sid), + slog.String("err", err.Error())) + r.evict(e) + return + } + + // Atomic swap: from this point on, /mcp requests dispatch to the new + // backend. The old backend's Done watcher (started in spawnSession) + // will fire once we Stop it, but compares against e.backend.Load() — + // since that now points at newBackend, the watcher is a no-op and + // the session survives the rotation. 
+ old := e.backend.Swap(newBackend) + e.mu.Lock() + e.oauthSess = &updated + e.mu.Unlock() + r.watchBackend(e.sid, newBackend) + + go func() { + stopCtx, c := context.WithTimeout(context.Background(), 5*time.Second) + defer c() + _ = old.Stop(stopCtx) + }() + r.log.Info("session rotated", + slog.String("sid", e.sid), + slog.String("user", updated.ForgejoUsername)) +} + +func nonZero(d, fallback time.Duration) time.Duration { + if d > 0 { + return d + } + return fallback +} diff --git a/internal/session/reaper_test.go b/internal/session/reaper_test.go new file mode 100644 index 0000000..60221aa --- /dev/null +++ b/internal/session/reaper_test.go @@ -0,0 +1,186 @@ +package session_test + +import ( + "context" + "errors" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log" + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth" + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/session" +) + +// fakeClock is a manually advanced clock for the reaper tests. The +// reaper goroutines tick on real wall time, so tests trigger eviction +// by waiting briefly between request and reap-interval expiry. +type fakeClock struct { + mu sync.Mutex + now time.Time +} + +func newFakeClock() *fakeClock { + return &fakeClock{now: time.Date(2026, 4, 27, 12, 0, 0, 0, time.UTC)} +} + +func (c *fakeClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.now +} + +func (c *fakeClock) Advance(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.now = c.now.Add(d) +} + +func TestReaper_EvictsIdleSession(t *testing.T) { + clk := newFakeClock() + spawn, backends := fakeSpawner(t) + r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard(), Now: clk.Now}) + srv := newTestServer(t, r) + + // Spawn a session. 
+ resp := doReq(t, srv.URL, "", bearerSess("idle-user"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + resp.Body.Close() + if r.Active() != 1 { + t.Fatalf("expected 1 active, got %d", r.Active()) + } + + // Push the clock past the idle timeout. + clk.Advance(time.Hour) + + // Start the reaper with a tight tick so the test runs quickly. + stop := r.StartReaper(session.ReaperConfig{ + IdleTimeout: 10 * time.Minute, + ReapInterval: 20 * time.Millisecond, + }) + defer stop() + + if !waitForActive(r, 0, 2*time.Second) { + t.Fatalf("session was not reaped: Active=%d", r.Active()) + } + if !backends.at(0).stopped.Load() { + t.Error("backend was not stopped on reap") + } +} + +func TestReaper_KeepsRecentlyActiveSession(t *testing.T) { + clk := newFakeClock() + spawn, _ := fakeSpawner(t) + r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard(), Now: clk.Now}) + srv := newTestServer(t, r) + + resp := doReq(t, srv.URL, "", bearerSess("active-user"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + resp.Body.Close() + + // Clock barely moves — well within the idle timeout. + clk.Advance(time.Minute) + + stop := r.StartReaper(session.ReaperConfig{ + IdleTimeout: 10 * time.Minute, + ReapInterval: 20 * time.Millisecond, + }) + defer stop() + + // Wait long enough for ≥1 reaper tick, then confirm the session is still + // alive. + time.Sleep(100 * time.Millisecond) + if r.Active() != 1 { + t.Errorf("active session was evicted prematurely: Active=%d", r.Active()) + } +} + +func TestRotator_RefreshesAndRespawns(t *testing.T) { + clk := newFakeClock() + spawn, backends := fakeSpawner(t) + r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard(), Now: clk.Now}) + srv := newTestServer(t, r) + + // The fake bearer's ForgejoTokenExp is the zero time, which is "well + // past expiry" by definition — the rotator should fire on first sweep. 
+ resp := doReq(t, srv.URL, "", bearerSess("rotate-user"), `{"jsonrpc":"2.0","id":1}`) + resp.Body.Close() + + var refreshCalls atomic.Int32 + refresh := func(ctx context.Context, sess *oauth.Session) (string, string, time.Time, error) { + refreshCalls.Add(1) + return "new-fj-access", "new-fj-refresh", clk.Now().Add(time.Hour), nil + } + + stop := r.StartReaper(session.ReaperConfig{ + IdleTimeout: time.Hour, // not testing idle here + ReapInterval: time.Hour, // disable idle reaper effectively + RotateInterval: 20 * time.Millisecond, + RefreshLead: 10 * time.Minute, + RefreshForgejo: refresh, + Respawn: spawn, // reuse the same fake; produces a new backend + }) + defer stop() + + // Wait for the rotator to spawn a replacement. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) && backends.count() < 2 { + time.Sleep(10 * time.Millisecond) + } + if backends.count() < 2 { + t.Fatalf("rotator did not spawn replacement; backends=%d, refreshes=%d", + backends.count(), refreshCalls.Load()) + } + + // Original backend was stopped, replacement is alive, session count unchanged. 
+ if !backends.at(0).stopped.Load() { + t.Error("original backend not stopped after rotation") + } + if r.Active() != 1 { + t.Errorf("Active = %d, want 1 (sid preserved across rotation)", r.Active()) + } +} + +func TestRotator_RefreshFailureEvictsSession(t *testing.T) { + clk := newFakeClock() + spawn, _ := fakeSpawner(t) + r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard(), Now: clk.Now}) + srv := newTestServer(t, r) + + resp := doReq(t, srv.URL, "", bearerSess("rotate-fail"), `{}`) + resp.Body.Close() + + refresh := func(context.Context, *oauth.Session) (string, string, time.Time, error) { + return "", "", time.Time{}, errors.New("forgejo refused") + } + + stop := r.StartReaper(session.ReaperConfig{ + IdleTimeout: time.Hour, + ReapInterval: time.Hour, + RotateInterval: 20 * time.Millisecond, + RefreshLead: 10 * time.Minute, + RefreshForgejo: refresh, + Respawn: spawn, + }) + defer stop() + + if !waitForActive(r, 0, 2*time.Second) { + t.Fatalf("session not evicted on refresh failure: Active=%d", r.Active()) + } +} + +func TestStartReaper_StopIsIdempotent(t *testing.T) { + clk := newFakeClock() + spawn, _ := fakeSpawner(t) + r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard(), Now: clk.Now}) + stop := r.StartReaper(session.ReaperConfig{ + IdleTimeout: time.Hour, + ReapInterval: time.Hour, + }) + stop() + stop() // must not panic +} + +// The blank-identifier declaration below keeps the net/http import in use +// during edits. Remove it once the file is stable. 
+var _ http.Handler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) diff --git a/internal/session/session.go b/internal/session/session.go index f45e365..1d8b31e 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -66,9 +66,11 @@ type Registry struct { type entry struct { sid string - backend *Backend - lastActive atomic.Int64 // unix nanoseconds; bumped per request - oauthSess *oauth.Session + backend atomic.Pointer[Backend] // swapped on rotation; readers use Load + lastActive atomic.Int64 // unix nanoseconds; bumped per request + + mu sync.Mutex // guards oauthSess only; backend is swapped atomically without it + oauthSess *oauth.Session } // SessionIDHeader is the streamable-HTTP MCP header that ferries the @@ -121,7 +123,7 @@ func (r *Registry) serve(w http.ResponseWriter, req *http.Request) { } w.Header().Set(SessionIDHeader, e.sid) e.lastActive.Store(r.now().UnixNano()) - e.backend.Handler.ServeHTTP(w, req) + e.backend.Load().Handler.ServeHTTP(w, req) return } @@ -133,7 +135,7 @@ func (r *Registry) serve(w http.ResponseWriter, req *http.Request) { return } e := v.(*entry) - if e.oauthSess.BrokerTokenHash != oauthSess.BrokerTokenHash { + if e.snapshotOAuth().BrokerTokenHash != oauthSess.BrokerTokenHash { // Session id is bound to the OAuth token that minted it. A // different bearer probing a stolen sid gets 403 — not 401, so // this is distinct from "your token is bad" and from "we don't @@ -142,7 +144,15 @@ func (r *Registry) serve(w http.ResponseWriter, req *http.Request) { return } e.lastActive.Store(r.now().UnixNano()) - e.backend.Handler.ServeHTTP(w, req) + e.backend.Load().Handler.ServeHTTP(w, req) +} + +// snapshotOAuth returns a pointer to the entry's current oauthSess under +// lock so callers don't see partial swaps during rotation. 
+func (e *entry) snapshotOAuth() *oauth.Session { + e.mu.Lock() + defer e.mu.Unlock() + return e.oauthSess } func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) (*entry, error) { @@ -156,7 +166,8 @@ func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) ( } sid := newSessionID() - e := &entry{sid: sid, backend: backend, oauthSess: oauthSess} + e := &entry{sid: sid, oauthSess: oauthSess} + e.backend.Store(backend) e.lastActive.Store(r.now().UnixNano()) if _, loaded := r.sessions.LoadOrStore(sid, e); loaded { @@ -166,12 +177,27 @@ func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) ( } r.count.Add(1) - // When the child exits on its own (crash, OOM, etc.), reap the entry. + r.watchBackend(sid, backend) + return e, nil +} + +// watchBackend launches a goroutine that removes the session if the given +// backend's Done closes WHILE that backend is still the entry's current +// one. After a rotation, the old backend's Done eventually closes too, +// but the entry now points at a new backend; in that case the watcher +// is a no-op so the session survives the rotation. 
+func (r *Registry) watchBackend(sid string, backend *Backend) { go func() { <-backend.Done - r.removeSession(sid) + v, ok := r.sessions.Load(sid) + if !ok { + return + } + e := v.(*entry) + if e.backend.Load() == backend { + r.removeSession(sid) + } }() - return e, nil } func (r *Registry) removeSession(sid string) { @@ -200,7 +226,7 @@ func (r *Registry) Active() int { return int(r.count.Load()) } func (r *Registry) Stop(ctx context.Context) { r.sessions.Range(func(k, v any) bool { e := v.(*entry) - if err := e.backend.Stop(ctx); err != nil { + if err := e.backend.Load().Stop(ctx); err != nil { r.log.Warn("session stop", slog.String("sid", e.sid), slog.String("err", err.Error())) } r.sessions.Delete(k) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 763636e..d705bb5 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -56,25 +56,57 @@ func (f *fakeBackend) backend() *session.Backend { } } -// fakeSpawner returns a SpawnFunc that hands out a sequence of fakeBackends. -// The returned slice is appended to as Spawn is called, so tests can -// inspect every backend that was minted. -func fakeSpawner(t *testing.T) (session.SpawnFunc, *[]*fakeBackend) { +// spawnerControl wraps a fake SpawnFunc with thread-safe access to the +// backends it has minted. Tests get raced freely without a control type +// because both spawn (called from registry goroutines) and tests access +// the slice; this lets tests query under a lock. 
+type spawnerControl struct { + mu sync.Mutex + backends []*fakeBackend + spawn session.SpawnFunc +} + +func newSpawnerControl(t *testing.T) *spawnerControl { t.Helper() - var ( - mu sync.Mutex - backends []*fakeBackend - next int - ) - spawn := func(ctx context.Context, sess *oauth.Session) (*session.Backend, error) { - mu.Lock() - defer mu.Unlock() - fb := newFakeBackend(next) - next++ - backends = append(backends, fb) + c := &spawnerControl{} + c.spawn = func(ctx context.Context, sess *oauth.Session) (*session.Backend, error) { + c.mu.Lock() + defer c.mu.Unlock() + fb := newFakeBackend(len(c.backends)) + c.backends = append(c.backends, fb) return fb.backend(), nil } - return spawn, &backends + return c +} + +func (c *spawnerControl) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.backends) +} + +func (c *spawnerControl) at(i int) *fakeBackend { + c.mu.Lock() + defer c.mu.Unlock() + if i >= len(c.backends) { + return nil + } + return c.backends[i] +} + +func (c *spawnerControl) snapshot() []*fakeBackend { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]*fakeBackend, len(c.backends)) + copy(out, c.backends) + return out +} + +// fakeSpawner is the legacy two-return adapter so existing tests keep +// compiling. New tests should prefer newSpawnerControl directly. 
+func fakeSpawner(t *testing.T) (session.SpawnFunc, *spawnerControl) { + c := newSpawnerControl(t) + return c.spawn, c } // testBearerHeader carries a bearer-hash discriminator across the wire so @@ -137,8 +169,8 @@ func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) { if r.Active() != 1 { t.Errorf("Active() = %d, want 1", r.Active()) } - if len(*backends) != 1 || (*backends)[0].requests.Load() != 1 { - t.Errorf("backend was not invoked exactly once: %+v", *backends) + if backends.count() != 1 || backends.at(0).requests.Load() != 1 { + t.Errorf("backend was not invoked exactly once: %+v", backends.snapshot()) } } @@ -162,11 +194,11 @@ func TestServe_KnownSid_ReusesBackend(t *testing.T) { if r.Active() != 1 { t.Errorf("Active() = %d, want 1 (reuse, not spawn)", r.Active()) } - if len(*backends) != 1 { - t.Errorf("Spawn called %d times, want 1", len(*backends)) + if backends.count() != 1 { + t.Errorf("Spawn called %d times, want 1", backends.count()) } - if (*backends)[0].requests.Load() != 2 { - t.Errorf("backend.requests = %d, want 2", (*backends)[0].requests.Load()) + if backends.at(0).requests.Load() != 2 { + t.Errorf("backend.requests = %d, want 2", backends.at(0).requests.Load()) } } @@ -260,7 +292,7 @@ func TestServe_BackendDone_RemovesSession(t *testing.T) { sid := first.Header.Get(session.SessionIDHeader) // Simulate the child exiting. - close((*backends)[0].done) + close(backends.at(0).done) // Wait for the reaper goroutine — poll Active() rather than add a // special hook to the production type. @@ -292,7 +324,7 @@ func TestStop_TearsDownAllSessions(t *testing.T) { if r.Active() != 0 { t.Errorf("Active after Stop = %d, want 0", r.Active()) } - for _, b := range *backends { + for _, b := range backends.snapshot() { if !b.stopped.Load() { t.Errorf("backend %d not stopped", b.id) }