Adds StartReaper to internal/session — two background goroutines that
keep the session map healthy under steady load.
Idle reaper:
- Sweeps every ReapInterval (default 30s) for sessions whose
LastActive is older than IdleTimeout (default 15m).
- Evicts via SIGTERM through the Backend.Stop hook.
Token rotator:
- Sweeps every RotateInterval (default 1m) for sessions whose Forgejo
token is within RefreshLead (default 5m) of expiry.
- Calls the operator-supplied RefreshForgejo to obtain new
access+refresh tokens, then Respawn to mint a new Backend with the
updated token in env.
- Atomically swaps e.backend (now an atomic.Pointer[Backend]); the
sid is preserved so the client just re-issues an MCP `initialize`
on its next request rather than re-authenticating.
- On refresh failure, evicts the session so the next /mcp request
triggers a clean re-auth instead of carrying a stale token.
Two race fixes uncovered by -race during this work:
- The Done-watcher started in spawnSession captured the original
backend pointer; after a rotation it still saw that backend's Done
close (because the old backend had been Stopped) and would evict the
entire entry. Fixed by evicting only when the watched backend is still
the current one (watched == e.backend.Load()).
- The fakeSpawner test helper let tests read the backends slice
without the lock the spawn callback held. Replaced with a
spawnerControl type whose count/at/snapshot methods all lock.
Tests cover idle eviction, recently-active sessions surviving sweeps,
successful rotation+respawn (sid preserved), refresh failure → eviction,
and Stop idempotency.
Closes forgejo-mcp-broker-q4x. Phase 5 complete.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
377 lines · 11 KiB · Go
package session_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/session"
|
|
)
|
|
|
|
// fakeBackend is a controllable session.Backend used by all tests in this
|
|
// package. It records every Handler invocation and exposes a Done channel
|
|
// the tests can close to simulate the child exiting.
|
|
type fakeBackend struct {
|
|
id int
|
|
handler func(w http.ResponseWriter, r *http.Request)
|
|
done chan struct{}
|
|
stopErr error
|
|
stopped atomic.Bool
|
|
requests atomic.Int32
|
|
}
|
|
|
|
func newFakeBackend(id int) *fakeBackend {
|
|
return &fakeBackend{id: id, done: make(chan struct{})}
|
|
}
|
|
|
|
func (f *fakeBackend) backend() *session.Backend {
|
|
return &session.Backend{
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
f.requests.Add(1)
|
|
if f.handler != nil {
|
|
f.handler(w, r)
|
|
return
|
|
}
|
|
fmt.Fprintf(w, "backend-%d-served", f.id)
|
|
}),
|
|
Stop: func(ctx context.Context) error {
|
|
f.stopped.Store(true)
|
|
select {
|
|
case <-f.done:
|
|
default:
|
|
close(f.done)
|
|
}
|
|
return f.stopErr
|
|
},
|
|
Done: f.done,
|
|
}
|
|
}
|
|
|
|
// spawnerControl wraps a fake SpawnFunc with thread-safe access to the
|
|
// backends it has minted. Tests get raced freely without a control type
|
|
// because both spawn (called from registry goroutines) and tests access
|
|
// the slice; this lets tests query under a lock.
|
|
type spawnerControl struct {
|
|
mu sync.Mutex
|
|
backends []*fakeBackend
|
|
spawn session.SpawnFunc
|
|
}
|
|
|
|
// newSpawnerControl builds a spawnerControl whose spawn func mints a fresh
// fakeBackend on every call, numbering them 0, 1, 2, … in spawn order. The
// context and oauth.Session passed to spawn are ignored, and spawn never
// fails.
func newSpawnerControl(t *testing.T) *spawnerControl {
	t.Helper()
	ctl := new(spawnerControl)
	ctl.spawn = func(_ context.Context, _ *oauth.Session) (*session.Backend, error) {
		ctl.mu.Lock()
		defer ctl.mu.Unlock()
		next := newFakeBackend(len(ctl.backends))
		ctl.backends = append(ctl.backends, next)
		return next.backend(), nil
	}
	return ctl
}
|
|
|
|
func (c *spawnerControl) count() int {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return len(c.backends)
|
|
}
|
|
|
|
func (c *spawnerControl) at(i int) *fakeBackend {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if i >= len(c.backends) {
|
|
return nil
|
|
}
|
|
return c.backends[i]
|
|
}
|
|
|
|
func (c *spawnerControl) snapshot() []*fakeBackend {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
out := make([]*fakeBackend, len(c.backends))
|
|
copy(out, c.backends)
|
|
return out
|
|
}
|
|
|
|
// fakeSpawner is the legacy two-return adapter so existing tests keep
// compiling. It returns the SpawnFunc to plug into session.Config together
// with the spawnerControl that observes it. New tests should prefer
// newSpawnerControl directly.
func fakeSpawner(t *testing.T) (session.SpawnFunc, *spawnerControl) {
	// Marked as a helper like the other test helpers in this file, so any
	// failure is attributed to the calling test.
	t.Helper()
	c := newSpawnerControl(t)
	return c.spawn, c
}
|
|
|
|
// testBearerHeader names the request header the test client uses to tell
// the test server which pseudo bearer-token hash it is presenting. net/http
// offers no way to hand a client-side context value to the server, so the
// newTestServer middleware reads this header and injects the matching
// oauth.Session. This stands in for what RequireBearer does in production.
const (
	testBearerHeader = "X-Test-Bearer-Hash"
)
|
|
|
|
// newTestServer starts an httptest.Server, closed automatically via
// t.Cleanup, that fronts r.Handler() with a small shim. When a request
// carries testBearerHeader, the shim attaches bearerSess(<header value>) to
// the request context. Requests without the header reach the registry with
// no oauth.Session in context, which the registry is expected to reject
// with 500, mirroring unauthenticated /mcp traffic in production.
func newTestServer(t *testing.T, r *session.Registry) *httptest.Server {
	t.Helper()
	withFakeAuth := func(w http.ResponseWriter, req *http.Request) {
		hash := req.Header.Get(testBearerHeader)
		if hash == "" {
			r.Handler().ServeHTTP(w, req)
			return
		}
		authed := req.WithContext(oauth.ContextWithSession(req.Context(), bearerSess(hash)))
		r.Handler().ServeHTTP(w, authed)
	}
	srv := httptest.NewServer(http.HandlerFunc(withFakeAuth))
	t.Cleanup(srv.Close)
	return srv
}
|
|
|
|
// helper used at the package level so tests don't have to construct
|
|
// oauth.Session manually each time.
|
|
func bearerSess(hash string) *oauth.Session {
|
|
return &oauth.Session{
|
|
ClientID: "client-" + hash,
|
|
ForgejoUsername: "user-" + hash,
|
|
BrokerTokenHash: hash,
|
|
ForgejoToken: "fj-token-" + hash,
|
|
Scopes: "read:user",
|
|
}
|
|
}
|
|
|
|
// TestServe_NewSession_MintsSidAndDispatches checks that a request without a
// session id spawns exactly one backend, returns a fresh session id header,
// and forwards the request to that backend exactly once.
func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	srv := newTestServer(t, r)

	resp := doReq(t, srv.URL, "", bearerSess("hash-A"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
	}
	if sid := resp.Header.Get(session.SessionIDHeader); sid == "" {
		t.Error("response missing Mcp-Session-Id header")
	}
	if got := r.Active(); got != 1 {
		t.Errorf("Active() = %d, want 1", got)
	}
	// Report spawn count and request count separately. Printing a
	// []*fakeBackend only shows pointer addresses, which says nothing
	// about what went wrong.
	if n := backends.count(); n != 1 {
		t.Fatalf("Spawn called %d times, want 1", n)
	}
	if got := backends.at(0).requests.Load(); got != 1 {
		t.Errorf("backend.requests = %d, want 1", got)
	}
}
|
|
|
|
// TestServe_KnownSid_ReusesBackend checks that a follow-up request carrying
// the sid minted by the first request is served by the same backend rather
// than spawning another.
func TestServe_KnownSid_ReusesBackend(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	bearer := bearerSess("hash-B")

	resp1 := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	resp1.Body.Close()
	sid := resp1.Header.Get(session.SessionIDHeader)
	if sid == "" {
		// Without a sid the second request would mint its own session,
		// and the reuse assertions below would fail for the wrong reason.
		t.Fatalf("initialize status = %d and no %s header", resp1.StatusCode, session.SessionIDHeader)
	}

	resp2 := doReq(t, srv.URL, sid, bearer, `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`)
	resp2.Body.Close()
	if resp2.StatusCode != http.StatusOK {
		t.Errorf("second request status = %d, want 200", resp2.StatusCode)
	}

	if got := r.Active(); got != 1 {
		t.Errorf("Active() = %d, want 1 (reuse, not spawn)", got)
	}
	// Fatal rather than Error: at(0) below is nil if nothing was spawned.
	if n := backends.count(); n != 1 {
		t.Fatalf("Spawn called %d times, want 1", n)
	}
	if got := backends.at(0).requests.Load(); got != 2 {
		t.Errorf("backend.requests = %d, want 2", got)
	}
}
|
|
|
|
// TestServe_UnknownSid_410 checks that a session id the registry never
// issued is answered with 410 Gone.
func TestServe_UnknownSid_410(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		// Fail clearly here instead of panicking on a nil *Registry below.
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	resp := doReq(t, srv.URL, "definitely-not-a-real-sid", bearerSess("hash-C"), `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusGone {
		t.Errorf("status = %d, want 410", resp.StatusCode)
	}
}
|
|
|
|
// TestServe_TokenMismatch_403 checks that a sid is bound to the bearer that
// created it. The same sid presented under a different bearer hash is
// rejected with 403.
func TestServe_TokenMismatch_403(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	first := doReq(t, srv.URL, "", bearerSess("alice"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	first.Body.Close()
	sid := first.Header.Get(session.SessionIDHeader)
	if sid == "" {
		// An empty sid would make eve's request mint its own session
		// instead of probing alice's.
		t.Fatalf("initialize status = %d and no %s header", first.StatusCode, session.SessionIDHeader)
	}

	hijack := doReq(t, srv.URL, sid, bearerSess("eve"), `{"jsonrpc":"2.0","id":2}`)
	hijack.Body.Close()
	if hijack.StatusCode != http.StatusForbidden {
		t.Errorf("status = %d, want 403", hijack.StatusCode)
	}
}
|
|
|
|
// TestServe_NoAuthSessionInContext_500 checks that reaching /mcp without
// RequireBearer having populated the context is surfaced loudly as a 500.
// That path is a programmer error, and the registry must not silently spawn
// an unauthenticated session.
func TestServe_NoAuthSessionInContext_500(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	// A nil session makes doReq omit testBearerHeader, so the test server
	// forwards the request with no oauth.Session in its context. Using
	// doReq instead of the bare http.Post also binds the request to
	// t.Context(), like every other request in this file.
	resp := doReq(t, srv.URL, "", nil, `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusInternalServerError {
		t.Errorf("status = %d, want 500", resp.StatusCode)
	}
}
|
|
|
|
// TestServe_MaxSessionsCap checks the MaxSessions cap. Once MaxSessions
// sessions exist, a further new-session request gets 503 with a Retry-After
// hint, and the active count stays at the cap.
func TestServe_MaxSessionsCap(t *testing.T) {
	spawn, _ := fakeSpawner(t)
	r, err := session.New(session.Config{
		Spawn: spawn, Log: brokerlog.Discard(), MaxSessions: 2,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	// Two sessions allowed.
	for i, hash := range []string{"a", "b"} {
		resp := doReq(t, srv.URL, "", bearerSess(hash), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("session %d: status = %d", i, resp.StatusCode)
		}
	}

	// Third hits the cap.
	resp := doReq(t, srv.URL, "", bearerSess("c"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Errorf("third session status = %d, want 503", resp.StatusCode)
	}
	if resp.Header.Get("Retry-After") == "" {
		t.Error("503 response should include Retry-After")
	}
	if got := r.Active(); got != 2 {
		t.Errorf("Active() = %d, want 2", got)
	}
}
|
|
|
|
// TestServe_BackendDone_RemovesSession checks what happens when the child
// exits on its own (its Done channel closes). The registry should drop the
// entry, so later traffic with that sid gets 410.
func TestServe_BackendDone_RemovesSession(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	bearer := bearerSess("crashed")
	first := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	first.Body.Close()
	sid := first.Header.Get(session.SessionIDHeader)
	if sid == "" {
		t.Fatalf("initialize status = %d and no %s header", first.StatusCode, session.SessionIDHeader)
	}
	// Guard against a nil dereference below if initialize never spawned.
	fb := backends.at(0)
	if fb == nil {
		t.Fatal("initialize did not spawn a backend")
	}

	// Simulate the child exiting.
	close(fb.done)

	// Wait for the reaper goroutine. Polling Active() avoids adding a
	// special hook to the production type.
	if !waitForActive(r, 0, 2*time.Second) {
		t.Fatalf("session count never dropped to 0")
	}

	resp := doReq(t, srv.URL, sid, bearer, `{}`)
	resp.Body.Close()
	if resp.StatusCode != http.StatusGone {
		t.Errorf("after backend.Done, status = %d, want 410", resp.StatusCode)
	}
}
|
|
|
|
// TestStop_TearsDownAllSessions checks that Registry.Stop drops every
// session and invokes Stop on every backend that was spawned.
func TestStop_TearsDownAllSessions(t *testing.T) {
	spawn, backends := fakeSpawner(t)
	r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	srv := newTestServer(t, r)

	for _, hash := range []string{"x", "y", "z"} {
		resp := doReq(t, srv.URL, "", bearerSess(hash), `{}`)
		resp.Body.Close()
	}
	if got := r.Active(); got != 3 {
		t.Fatalf("Active = %d, want 3", got)
	}

	r.Stop(context.Background())
	if got := r.Active(); got != 0 {
		t.Errorf("Active after Stop = %d, want 0", got)
	}
	for _, b := range backends.snapshot() {
		if !b.stopped.Load() {
			t.Errorf("backend %d not stopped", b.id)
		}
	}
}
|
|
|
|
// TestNew_RequiresSpawn checks that session.New rejects a Config without a
// Spawn func and names the missing field in its error.
func TestNew_RequiresSpawn(t *testing.T) {
	_, err := session.New(session.Config{Log: brokerlog.Discard()})
	if err != nil && strings.Contains(err.Error(), "Spawn") {
		return
	}
	t.Errorf("want Spawn-required error, got %v", err)
}
|
|
|
|
// doReq POSTs body to url with Content-Type application/json and returns
// the response; the caller must close resp.Body. Any transport error fails
// the test immediately.
//
// A non-empty sid is sent in the session.SessionIDHeader header. A non-nil
// sess contributes only its BrokerTokenHash, which is sent in
// testBearerHeader for newTestServer to turn back into an oauth.Session;
// the other fields of sess do not cross the wire. Pass a nil sess (or one
// with an empty BrokerTokenHash) to simulate an unauthenticated request.
//
// The request is bound to t.Context(), so it is cancelled once the test
// finishes.
func doReq(t *testing.T, url, sid string, sess *oauth.Session, body string) *http.Response {
	t.Helper()
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, url, strings.NewReader(body))
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if sid != "" {
		req.Header.Set(session.SessionIDHeader, sid)
	}
	if sess != nil {
		req.Header.Set(testBearerHeader, sess.BrokerTokenHash)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	return resp
}
|
|
|
|
// waitForActive samples r.Active() every 5ms until it reports target
// (returning true) or the within window elapses (returning false). It lets
// tests wait on the registry's background goroutines without a dedicated
// synchronisation hook in production code.
func waitForActive(r *session.Registry, target int, within time.Duration) bool {
	for stopAt := time.Now().Add(within); time.Now().Before(stopAt); time.Sleep(5 * time.Millisecond) {
		if r.Active() == target {
			return true
		}
	}
	return false
}
|