package session_test

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/session"
)

// fakeBackend is a controllable session.Backend used by all tests in this
// package. It counts every Handler invocation in requests and exposes a done
// channel that tests can close to simulate the child exiting.
type fakeBackend struct {
	id int
	// handler optionally overrides the default reply; when nil the backend
	// answers "backend-<id>-served".
	handler  func(w http.ResponseWriter, r *http.Request)
	done     chan struct{}
	stopErr  error
	stopped  atomic.Bool
	requests atomic.Int32

	// closeMu serialises Stop's check-then-close of done. Without it two
	// concurrent Stop calls could both observe done as open and both close
	// it, which panics.
	closeMu sync.Mutex
}

// newFakeBackend returns a fakeBackend with the given id and an open done
// channel.
func newFakeBackend(id int) *fakeBackend {
	return &fakeBackend{id: id, done: make(chan struct{})}
}

// backend adapts f to the session.Backend shape the registry consumes.
// Handler bumps the request counter before dispatching. Stop marks the
// backend stopped, closes done (at most once) and returns stopErr.
func (f *fakeBackend) backend() *session.Backend {
	return &session.Backend{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			f.requests.Add(1)
			if f.handler != nil {
				f.handler(w, r)
				return
			}
			fmt.Fprintf(w, "backend-%d-served", f.id)
		}),
		Stop: func(ctx context.Context) error {
			f.stopped.Store(true)
			f.closeMu.Lock()
			defer f.closeMu.Unlock()
			select {
			case <-f.done:
				// Already closed, either by an earlier Stop or by a test
				// closing done directly to simulate a crash.
			default:
				close(f.done)
			}
			return f.stopErr
		},
		Done: f.done,
	}
}

// spawnerControl wraps a fake SpawnFunc and records every backend it mints.
// Spawn runs on registry goroutines while the tests inspect the results
// concurrently, so all access to the backends slice goes through the
// mutex-guarded accessors.
type spawnerControl struct { mu sync.Mutex backends []*fakeBackend spawn session.SpawnFunc } func newSpawnerControl(t *testing.T) *spawnerControl { t.Helper() c := &spawnerControl{} c.spawn = func(ctx context.Context, sess *oauth.Session) (*session.Backend, error) { c.mu.Lock() defer c.mu.Unlock() fb := newFakeBackend(len(c.backends)) c.backends = append(c.backends, fb) return fb.backend(), nil } return c } func (c *spawnerControl) count() int { c.mu.Lock() defer c.mu.Unlock() return len(c.backends) } func (c *spawnerControl) at(i int) *fakeBackend { c.mu.Lock() defer c.mu.Unlock() if i >= len(c.backends) { return nil } return c.backends[i] } func (c *spawnerControl) snapshot() []*fakeBackend { c.mu.Lock() defer c.mu.Unlock() out := make([]*fakeBackend, len(c.backends)) copy(out, c.backends) return out } // fakeSpawner is the legacy two-return adapter so existing tests keep // compiling. New tests should prefer newSpawnerControl directly. func fakeSpawner(t *testing.T) (session.SpawnFunc, *spawnerControl) { c := newSpawnerControl(t) return c.spawn, c } // testBearerHeader carries a bearer-hash discriminator across the wire so // the test server can swap in the right oauth.Session per request. We // can't propagate context from client to server through net/http, so this // header substitutes for what RequireBearer would otherwise inject. const testBearerHeader = "X-Test-Bearer-Hash" // newTestServer wraps the Registry handler with a tiny middleware that // reads testBearerHeader and attaches a matching oauth.Session to the // request context. Without the header (i.e. simulating no auth), the // registry sees a context without a session and returns 500 — exactly // the production behaviour for unauthenticated /mcp traffic. 
func newTestServer(t *testing.T, r *session.Registry) *httptest.Server { t.Helper() handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if hash := req.Header.Get(testBearerHeader); hash != "" { ctx := oauth.ContextWithSession(req.Context(), bearerSess(hash)) req = req.WithContext(ctx) } r.Handler().ServeHTTP(w, req) }) srv := httptest.NewServer(handler) t.Cleanup(srv.Close) return srv } // helper used at the package level so tests don't have to construct // oauth.Session manually each time. func bearerSess(hash string) *oauth.Session { return &oauth.Session{ ClientID: "client-" + hash, ForgejoUsername: "user-" + hash, BrokerTokenHash: hash, ForgejoToken: "fj-token-" + hash, Scopes: "read:user", } } func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) { spawn, backends := fakeSpawner(t) r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) if err != nil { t.Fatalf("New: %v", err) } srv := newTestServer(t, r) resp := doReq(t, srv.URL, "", bearerSess("hash-A"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body) } sid := resp.Header.Get(session.SessionIDHeader) if sid == "" { t.Error("response missing Mcp-Session-Id header") } if r.Active() != 1 { t.Errorf("Active() = %d, want 1", r.Active()) } if backends.count() != 1 || backends.at(0).requests.Load() != 1 { t.Errorf("backend was not invoked exactly once: %+v", backends.snapshot()) } } func TestServe_KnownSid_ReusesBackend(t *testing.T) { spawn, backends := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) bearer := bearerSess("hash-B") resp1 := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) resp1.Body.Close() sid := resp1.Header.Get(session.SessionIDHeader) resp2 := doReq(t, srv.URL, sid, bearer, 
`{"jsonrpc":"2.0","id":2,"method":"tools/list"}`) resp2.Body.Close() if resp2.StatusCode != http.StatusOK { t.Errorf("second request status = %d, want 200", resp2.StatusCode) } if r.Active() != 1 { t.Errorf("Active() = %d, want 1 (reuse, not spawn)", r.Active()) } if backends.count() != 1 { t.Errorf("Spawn called %d times, want 1", backends.count()) } if backends.at(0).requests.Load() != 2 { t.Errorf("backend.requests = %d, want 2", backends.at(0).requests.Load()) } } func TestServe_UnknownSid_410(t *testing.T) { spawn, _ := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) resp := doReq(t, srv.URL, "definitely-not-a-real-sid", bearerSess("hash-C"), `{}`) resp.Body.Close() if resp.StatusCode != http.StatusGone { t.Errorf("status = %d, want 410", resp.StatusCode) } } func TestServe_TokenMismatch_403(t *testing.T) { // Two different bearer hashes for the same sid: only the original // owner can access. spawn, _ := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) first := doReq(t, srv.URL, "", bearerSess("alice"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) first.Body.Close() sid := first.Header.Get(session.SessionIDHeader) hijack := doReq(t, srv.URL, sid, bearerSess("eve"), `{"jsonrpc":"2.0","id":2}`) hijack.Body.Close() if hijack.StatusCode != http.StatusForbidden { t.Errorf("status = %d, want 403", hijack.StatusCode) } } func TestServe_NoAuthSessionInContext_500(t *testing.T) { // Calling /mcp without going through RequireBearer first is a // programmer error; the registry surfaces it loudly rather than // silently spawning an unauthenticated session. 
spawn, _ := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) resp, err := http.Post(srv.URL, "application/json", strings.NewReader(`{}`)) if err != nil { t.Fatalf("post: %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusInternalServerError { t.Errorf("status = %d, want 500", resp.StatusCode) } } func TestServe_MaxSessionsCap(t *testing.T) { spawn, _ := fakeSpawner(t) r, _ := session.New(session.Config{ Spawn: spawn, Log: brokerlog.Discard(), MaxSessions: 2, }) srv := newTestServer(t, r) // Two sessions allowed. for i, hash := range []string{"a", "b"} { resp := doReq(t, srv.URL, "", bearerSess(hash), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Fatalf("session %d: status = %d", i, resp.StatusCode) } } // Third hits the cap. resp := doReq(t, srv.URL, "", bearerSess("c"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) resp.Body.Close() if resp.StatusCode != http.StatusServiceUnavailable { t.Errorf("third session status = %d, want 503", resp.StatusCode) } if resp.Header.Get("Retry-After") == "" { t.Error("503 response should include Retry-After") } if r.Active() != 2 { t.Errorf("Active() = %d, want 2", r.Active()) } } func TestServe_BackendDone_RemovesSession(t *testing.T) { // When the child exits on its own (Done closes), the registry should // reap the entry so subsequent traffic with that sid gets 410. spawn, backends := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) bearer := bearerSess("crashed") first := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) first.Body.Close() sid := first.Header.Get(session.SessionIDHeader) // Simulate the child exiting. close(backends.at(0).done) // Wait for the reaper goroutine — poll Active() rather than add a // special hook to the production type. 
if !waitForActive(r, 0, 2*time.Second) { t.Fatalf("session count never dropped to 0") } resp := doReq(t, srv.URL, sid, bearer, `{}`) resp.Body.Close() if resp.StatusCode != http.StatusGone { t.Errorf("after backend.Done, status = %d, want 410", resp.StatusCode) } } func TestStop_TearsDownAllSessions(t *testing.T) { spawn, backends := fakeSpawner(t) r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()}) srv := newTestServer(t, r) for _, hash := range []string{"x", "y", "z"} { resp := doReq(t, srv.URL, "", bearerSess(hash), `{}`) resp.Body.Close() } if r.Active() != 3 { t.Fatalf("Active = %d, want 3", r.Active()) } r.Stop(context.Background()) if r.Active() != 0 { t.Errorf("Active after Stop = %d, want 0", r.Active()) } for _, b := range backends.snapshot() { if !b.stopped.Load() { t.Errorf("backend %d not stopped", b.id) } } } func TestNew_RequiresSpawn(t *testing.T) { _, err := session.New(session.Config{Log: brokerlog.Discard()}) if err == nil || !strings.Contains(err.Error(), "Spawn") { t.Errorf("want Spawn-required error, got %v", err) } } // doReq POSTs to the test server with an optional session id and a // pseudo-bearer hash that the test middleware translates into an // oauth.Session. Pass an empty hash to simulate an unauthenticated // request. func doReq(t *testing.T, url, sid string, sess *oauth.Session, body string) *http.Response { t.Helper() req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, url, strings.NewReader(body)) if err != nil { t.Fatalf("new request: %v", err) } req.Header.Set("Content-Type", "application/json") if sid != "" { req.Header.Set(session.SessionIDHeader, sid) } if sess != nil { req.Header.Set(testBearerHeader, sess.BrokerTokenHash) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } return resp } // waitForActive polls r.Active() until it equals target or the deadline // expires. Returns true on a hit. 
Used by the backend-Done test to give // the registry's reaper goroutine time to run. func waitForActive(r *session.Registry, target int, within time.Duration) bool { deadline := time.Now().Add(within) for time.Now().Before(deadline) { if r.Active() == target { return true } time.Sleep(5 * time.Millisecond) } return false }