forgejo-mcp-broker/internal/session/session_test.go
Ole-Morten Duesund 933e7bd369 feat(session): idle reaper + Forgejo token rotation (forgejo-mcp-broker-q4x)
Adds StartReaper to internal/session — two background goroutines that
keep the session map healthy under steady load.

Idle reaper:
  - Sweeps every ReapInterval (default 30s) for sessions whose
    LastActive is older than IdleTimeout (default 15m).
  - Evicts via SIGTERM through the Backend.Stop hook.

Token rotator:
  - Sweeps every RotateInterval (default 1m) for sessions whose Forgejo
    token is within RefreshLead (default 5m) of expiry.
  - Calls the operator-supplied RefreshForgejo to obtain new
    access+refresh tokens, then Respawn to mint a new Backend with the
    updated token in env.
  - Atomically swaps e.backend (now an atomic.Pointer[Backend]); the
    sid is preserved so the client just re-issues an MCP `initialize`
    on its next request rather than re-authenticating.
  - On refresh failure, evicts so the next /mcp produces a clean
    re-auth instead of carrying a stale token.

Two race fixes uncovered by -race during this work:
  - The Done-watcher started in spawnSession captured the original
    backend pointer; after rotation it still saw Done close (because
    the old backend was Stopped) and would yank the entire entry. Fixed
    by comparing watched-backend == e.backend.Load() before evicting.
  - The fakeSpawner test helper let tests read the backends slice
    without the lock the spawn callback held. Replaced with a
    spawnerControl type whose count/at/snapshot methods all lock.

Tests cover idle eviction, recently-active sessions surviving sweeps,
successful rotation+respawn (sid preserved), refresh failure → eviction,
and Stop idempotency.

Closes forgejo-mcp-broker-q4x. Phase 5 complete.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 17:32:36 +02:00

377 lines
11 KiB
Go

package session_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/session"
)
// fakeBackend is a scriptable stand-in for a session.Backend, shared by the
// tests in this package. It counts the requests its Handler receives and
// owns a done channel that a test may close by hand to mimic the child
// process exiting on its own.
type fakeBackend struct {
	// id is the spawn ordinal; the default handler echoes it back.
	id int
	// handler, when non-nil, replaces the default response writer.
	handler func(w http.ResponseWriter, r *http.Request)
	// done is exposed as Backend.Done; Stop closes it.
	done chan struct{}
	// stopErr is returned from every Stop call.
	stopErr error
	// stopped is set once Stop has been called at least once.
	stopped atomic.Bool
	// requests counts Handler invocations.
	requests atomic.Int32
}
// newFakeBackend returns a fakeBackend tagged with spawn index id whose
// done channel is still open, i.e. the simulated child is running.
func newFakeBackend(id int) *fakeBackend {
	fb := new(fakeBackend)
	fb.id = id
	fb.done = make(chan struct{})
	return fb
}
// backend adapts f into the session.Backend shape the registry consumes.
//
// Handler increments f.requests, then delegates to f.handler when one is
// set and otherwise writes "backend-<id>-served".
//
// Stop records that it ran, closes f.done unless it is already closed (so
// repeated sequential calls are harmless), and returns f.stopErr.
func (f *fakeBackend) backend() *session.Backend {
	serve := func(w http.ResponseWriter, r *http.Request) {
		f.requests.Add(1)
		if custom := f.handler; custom != nil {
			custom(w, r)
			return
		}
		fmt.Fprintf(w, "backend-%d-served", f.id)
	}
	stop := func(ctx context.Context) error {
		f.stopped.Store(true)
		select {
		case <-f.done:
			// Already closed (by an earlier Stop or directly by a test).
		default:
			close(f.done)
		}
		return f.stopErr
	}
	return &session.Backend{
		Handler: http.HandlerFunc(serve),
		Stop:    stop,
		Done:    f.done,
	}
}
// spawnerControl owns a fake session.SpawnFunc and remembers every
// fakeBackend it has minted. spawn runs on registry goroutines while the
// test goroutine inspects the results, so the backends slice is only ever
// touched under mu. Tests should read it through count, at and snapshot
// rather than directly, or -race will flag the access.
type spawnerControl struct {
	mu       sync.Mutex     // guards backends
	backends []*fakeBackend // every backend minted, in spawn order
	spawn    session.SpawnFunc
}
// newSpawnerControl builds a spawnerControl whose spawn func hands out a
// fresh fakeBackend on every call. Backends are numbered 0, 1, 2, … in call
// order and appended to the control's list under its mutex. Spawning never
// fails.
func newSpawnerControl(t *testing.T) *spawnerControl {
	t.Helper()
	ctl := new(spawnerControl)
	ctl.spawn = func(context.Context, *oauth.Session) (*session.Backend, error) {
		ctl.mu.Lock()
		fb := newFakeBackend(len(ctl.backends))
		ctl.backends = append(ctl.backends, fb)
		ctl.mu.Unlock()
		return fb.backend(), nil
	}
	return ctl
}
// count reports how many backends spawn has minted so far.
func (c *spawnerControl) count() int {
	c.mu.Lock()
	n := len(c.backends)
	c.mu.Unlock()
	return n
}
// at returns the i-th backend spawned (0-based), or nil when no such
// backend exists. Negative indices also yield nil instead of triggering a
// runtime index panic, so callers can report the problem with their own
// nil check.
func (c *spawnerControl) at(i int) *fakeBackend {
	c.mu.Lock()
	defer c.mu.Unlock()
	if i < 0 || i >= len(c.backends) {
		return nil
	}
	return c.backends[i]
}
// snapshot returns a private copy of every backend minted so far, in spawn
// order. The copy can be ranged over freely without holding the lock.
func (c *spawnerControl) snapshot() []*fakeBackend {
	c.mu.Lock()
	defer c.mu.Unlock()
	return append(make([]*fakeBackend, 0, len(c.backends)), c.backends...)
}
// fakeSpawner is the older two-value form of newSpawnerControl, kept so
// existing tests keep compiling. It returns the SpawnFunc to put in
// session.Config alongside the control used to inspect minted backends.
// New tests should call newSpawnerControl directly.
func fakeSpawner(t *testing.T) (session.SpawnFunc, *spawnerControl) {
	// Mark as a helper like the other *testing.T helpers in this file, so
	// any failure raised beneath it is attributed to the calling test.
	t.Helper()
	c := newSpawnerControl(t)
	return c.spawn, c
}
// testBearerHeader names the request header that newTestServer reads to
// decide which oauth.Session to attach to the request context. net/http
// cannot carry a context value from client to server, so this header stands
// in for the session RequireBearer would otherwise inject.
const testBearerHeader = "X-Test-Bearer-Hash"
// newTestServer starts an httptest.Server in front of r's handler and
// registers its shutdown with t.Cleanup.
//
// A small middleware stands in for RequireBearer. When testBearerHeader is
// non-empty, it attaches bearerSess(value) to the request context.
// Requests without the header reach the registry with no oauth.Session in
// context, and the registry answers them with 500. That models /mcp being
// mounted without RequireBearer in front of it (a wiring error), not an
// ordinary unauthenticated client.
func newTestServer(t *testing.T, r *session.Registry) *httptest.Server {
	t.Helper()
	// Resolve the registry's handler once instead of on every request.
	inner := r.Handler()
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if hash := req.Header.Get(testBearerHeader); hash != "" {
			req = req.WithContext(oauth.ContextWithSession(req.Context(), bearerSess(hash)))
		}
		inner.ServeHTTP(w, req)
	})
	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	return srv
}
// bearerSess builds the oauth.Session that newTestServer attaches for a
// given bearer hash, so tests never assemble one by hand. Every identity
// field is derived from hash, so distinct hashes look like distinct
// clients and users.
func bearerSess(hash string) *oauth.Session {
	sess := &oauth.Session{
		BrokerTokenHash: hash,
		Scopes:          "read:user",
	}
	sess.ClientID = "client-" + hash
	sess.ForgejoUsername = "user-" + hash
	sess.ForgejoToken = "fj-token-" + hash
	return sess
}
// TestServe_NewSession_MintsSidAndDispatches checks that a request without
// a sid spawns exactly one backend, returns a fresh sid, and is forwarded
// to that backend.
func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	resp := doReq(t, srv.URL, "", bearerSess("hash-A"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
	}
	if sid := resp.Header.Get(session.SessionIDHeader); sid == "" {
		t.Error("response missing Mcp-Session-Id header")
	}
	if got := r.Active(); got != 1 {
		t.Errorf("Active() = %d, want 1", got)
	}
	// Report spawn and request counts separately: formatting the snapshot
	// with %+v would only print pointer addresses.
	if n := backends.count(); n != 1 {
		t.Fatalf("Spawn called %d times, want 1", n)
	}
	if got := backends.at(0).requests.Load(); got != 1 {
		t.Errorf("backend.requests = %d, want 1", got)
	}
}
// TestServe_KnownSid_ReusesBackend checks that a follow-up request carrying
// the sid from initialize is routed to the same backend instead of
// spawning a new one.
func TestServe_KnownSid_ReusesBackend(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)
	bearer := bearerSess("hash-B")

	resp1 := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	resp1.Body.Close()
	sid := resp1.Header.Get(session.SessionIDHeader)
	// With an empty sid, the second request would mint a second session and
	// the reuse assertions below would fail for the wrong reason.
	if resp1.StatusCode != http.StatusOK || sid == "" {
		t.Fatalf("initialize: status = %d, sid = %q; want 200 and a sid", resp1.StatusCode, sid)
	}

	resp2 := doReq(t, srv.URL, sid, bearer, `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`)
	resp2.Body.Close()
	if resp2.StatusCode != http.StatusOK {
		t.Errorf("second request status = %d, want 200", resp2.StatusCode)
	}
	if got := r.Active(); got != 1 {
		t.Errorf("Active() = %d, want 1 (reuse, not spawn)", got)
	}
	if n := backends.count(); n != 1 {
		t.Fatalf("Spawn called %d times, want 1", n)
	}
	if got := backends.at(0).requests.Load(); got != 2 {
		t.Errorf("backend.requests = %d, want 2", got)
	}
}
// TestServe_UnknownSid_410 checks that a sid the registry never issued is
// answered with 410 Gone.
func TestServe_UnknownSid_410(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	resp := doReq(t, srv.URL, "definitely-not-a-real-sid", bearerSess("hash-C"), `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusGone {
		t.Errorf("status = %d, want 410", resp.StatusCode)
	}
}
// TestServe_TokenMismatch_403 checks that a sid presented under a different
// bearer than the one that created it is rejected with 403.
func TestServe_TokenMismatch_403(t *testing.T) {
	// Two different bearer hashes for the same sid: only the original
	// owner can access.
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	first := doReq(t, srv.URL, "", bearerSess("alice"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	first.Body.Close()
	sid := first.Header.Get(session.SessionIDHeader)
	// With an empty sid, eve's request would simply open a fresh session of
	// her own and the 403 assertion would fail misleadingly.
	if first.StatusCode != http.StatusOK || sid == "" {
		t.Fatalf("alice initialize: status = %d, sid = %q; want 200 and a sid", first.StatusCode, sid)
	}

	hijack := doReq(t, srv.URL, sid, bearerSess("eve"), `{"jsonrpc":"2.0","id":2}`)
	hijack.Body.Close()
	if hijack.StatusCode != http.StatusForbidden {
		t.Errorf("status = %d, want 403", hijack.StatusCode)
	}
}
// TestServe_NoAuthSessionInContext_500 checks that the registry refuses to
// serve a request whose context carries no oauth.Session.
func TestServe_NoAuthSessionInContext_500(t *testing.T) {
	// Calling /mcp without going through RequireBearer first is a
	// programmer error; the registry surfaces it loudly rather than
	// silently spawning an unauthenticated session.
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	// A nil session makes doReq omit the test bearer header, so the
	// middleware attaches nothing. Going through doReq (rather than a bare
	// http.Post) also binds the request to t.Context().
	resp := doReq(t, srv.URL, "", nil, `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusInternalServerError {
		t.Errorf("status = %d, want 500", resp.StatusCode)
	}
}
// TestServe_MaxSessionsCap checks that once MaxSessions sessions are live,
// a further new-session request gets 503 with Retry-After and spawns
// nothing.
func TestServe_MaxSessionsCap(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{
		Spawn: spawn, Log: brokerlog.Discard(), MaxSessions: 2,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	// Two sessions allowed.
	for i, hash := range []string{"a", "b"} {
		resp := doReq(t, srv.URL, "", bearerSess(hash), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("session %d: status = %d, want 200", i, resp.StatusCode)
		}
	}

	// Third hits the cap.
	resp := doReq(t, srv.URL, "", bearerSess("c"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Errorf("third session status = %d, want 503", resp.StatusCode)
	}
	if resp.Header.Get("Retry-After") == "" {
		t.Error("503 response should include Retry-After")
	}
	if got := r.Active(); got != 2 {
		t.Errorf("Active() = %d, want 2", got)
	}
}
// TestServe_BackendDone_RemovesSession checks that a backend exiting on its
// own causes its session to be evicted.
func TestServe_BackendDone_RemovesSession(t *testing.T) {
	// When the child exits on its own (Done closes), the registry should
	// reap the entry so subsequent traffic with that sid gets 410.
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)
	bearer := bearerSess("crashed")

	first := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	first.Body.Close()
	sid := first.Header.Get(session.SessionIDHeader)
	if first.StatusCode != http.StatusOK || sid == "" {
		t.Fatalf("initialize: status = %d, sid = %q; want 200 and a sid", first.StatusCode, sid)
	}
	fb := backends.at(0)
	if fb == nil {
		t.Fatal("Spawn was never called")
	}

	// Simulate the child exiting.
	close(fb.done)

	// Wait for the registry's Done-watcher goroutine to evict the entry.
	// Poll Active() rather than add a special hook to the production type.
	if !waitForActive(r, 0, 2*time.Second) {
		t.Fatalf("session count never dropped to 0")
	}

	resp := doReq(t, srv.URL, sid, bearer, `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusGone {
		t.Errorf("after backend.Done, status = %d, want 410", resp.StatusCode)
	}
}
// TestStop_TearsDownAllSessions checks that Registry.Stop empties the
// session map and calls Stop on every backend it spawned.
func TestStop_TearsDownAllSessions(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	for _, hash := range []string{"x", "y", "z"} {
		resp := doReq(t, srv.URL, "", bearerSess(hash), `{}`)
		resp.Body.Close()
	}
	if got := r.Active(); got != 3 {
		t.Fatalf("Active = %d, want 3", got)
	}
	// Without this, the stopped-backend loop below would pass vacuously if
	// nothing had been spawned.
	if n := backends.count(); n != 3 {
		t.Fatalf("Spawn called %d times, want 3", n)
	}

	r.Stop(context.Background())

	if got := r.Active(); got != 0 {
		t.Errorf("Active after Stop = %d, want 0", got)
	}
	for _, b := range backends.snapshot() {
		if !b.stopped.Load() {
			t.Errorf("backend %d not stopped", b.id)
		}
	}
}
// TestNew_RequiresSpawn checks that New rejects a Config with no Spawn func
// and that the error mentions the missing field.
func TestNew_RequiresSpawn(t *testing.T) {
	_, err := session.New(session.Config{Log: brokerlog.Discard()})
	mentionsSpawn := err != nil && strings.Contains(err.Error(), "Spawn")
	if !mentionsSpawn {
		t.Errorf("want Spawn-required error, got %v", err)
	}
}
// doReq POSTs body as JSON to url and returns the response. The caller
// must close resp.Body. The request is bound to t.Context(), and any
// request-construction or transport error fails the test immediately.
//
// A non-empty sid is sent in the session.SessionIDHeader header.
//
// When sess is non-nil, only sess.BrokerTokenHash crosses the wire (in
// testBearerHeader); newTestServer rebuilds the session from it with
// bearerSess. Pass a nil sess to simulate a request with no auth session:
// the header is omitted and no oauth.Session is attached server-side.
func doReq(t *testing.T, url, sid string, sess *oauth.Session, body string) *http.Response {
	t.Helper()
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, url, strings.NewReader(body))
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if sid != "" {
		req.Header.Set(session.SessionIDHeader, sid)
	}
	if sess != nil {
		req.Header.Set(testBearerHeader, sess.BrokerTokenHash)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	return resp
}
// waitForActive polls r.Active() every 5ms until it equals target or within
// has elapsed, and reports whether target was observed. It takes one last
// sample after the deadline, so a change that lands during the final sleep
// is not misreported as a timeout. Used, for example, by the backend-Done
// test to give the registry's background Done-watcher goroutine time to
// evict a session.
func waitForActive(r *session.Registry, target int, within time.Duration) bool {
	deadline := time.Now().Add(within)
	for time.Now().Before(deadline) {
		if r.Active() == target {
			return true
		}
		time.Sleep(5 * time.Millisecond)
	}
	return r.Active() == target
}