Adds internal/session.Registry, the MCP session glue that maps
Mcp-Session-Id to a running forgejo-mcp child + bridge.
Lifecycle:
- First /mcp POST without Mcp-Session-Id: SpawnFunc creates a backend
(in production: supervisor.Start + bridge.New); registry mints a
192-bit hex session id, attaches it to the response header, and
dispatches the request to the new backend.
- Subsequent POSTs with the header dispatch to the existing backend.
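- Unknown sids → 410 Gone, so clients re-initialise instead of retrying
forever. (Note: the MCP Streamable HTTP spec prescribes 404 Not Found for
unknown or terminated sessions; confirm target clients also re-initialise
on 410.)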
- Sids are bound to the OAuth token that minted them: a different
bearer probing a stolen sid gets 403, distinct from "your token is
bad" (401) and "sid unknown" (410).
Cleanup:
- When backend.Done closes (child exited on its own — crash, OOM,
user-driven shutdown), a goroutine reaps the entry.
- Stop tears down every session on broker shutdown. The 30s idle
reaper and Forgejo token rotation will land in 5c.
The Registry is decoupled from supervisor and bridge via SpawnFunc, so
tests don't need to fork real processes — they hand the registry a fake
that returns a controllable Backend. Also adds oauth.ContextWithSession
so the session tests can inject an oauth.Session into request contexts
without standing up the full bearer-middleware chain.
Tests (83.3% coverage) cover spawn-on-initialize, sid reuse, unknown
sid, the max-session cap with Retry-After, the no-auth-context guard,
sid-hijack defence (token mismatch → 403), Done-channel reaping, and
graceful Stop.
Closes forgejo-mcp-broker-t81.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
345 lines
10 KiB
Go
345 lines
10 KiB
Go
package session_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/session"
|
|
)
|
|
|
|
// fakeBackend stands in for a real forgejo-mcp child + bridge. Every test in
// this package mints its backends through it, so each one can count the
// requests routed to it and have its done channel closed to pretend the
// child exited.
type fakeBackend struct {
	// id tells backends apart in the default response body and in failure
	// messages.
	id int
	// handler, when non-nil, replaces the default "backend-<id>-served"
	// response.
	handler func(w http.ResponseWriter, r *http.Request)
	// done is handed to the registry as the backend's Done channel; closing
	// it simulates the child exiting.
	done chan struct{}
	// stopErr is the value the backend's Stop function returns.
	stopErr error
	// stopped records whether Stop was ever invoked.
	stopped atomic.Bool
	// requests counts every invocation of the backend's Handler.
	requests atomic.Int32
}

// newFakeBackend returns a fakeBackend carrying id, with an open done channel
// and every other field at its zero value.
func newFakeBackend(id int) *fakeBackend {
	fb := &fakeBackend{id: id}
	fb.done = make(chan struct{})
	return fb
}
|
|
|
|
func (f *fakeBackend) backend() *session.Backend {
|
|
return &session.Backend{
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
f.requests.Add(1)
|
|
if f.handler != nil {
|
|
f.handler(w, r)
|
|
return
|
|
}
|
|
fmt.Fprintf(w, "backend-%d-served", f.id)
|
|
}),
|
|
Stop: func(ctx context.Context) error {
|
|
f.stopped.Store(true)
|
|
select {
|
|
case <-f.done:
|
|
default:
|
|
close(f.done)
|
|
}
|
|
return f.stopErr
|
|
},
|
|
Done: f.done,
|
|
}
|
|
}
|
|
|
|
// fakeSpawner returns a SpawnFunc that hands out a sequence of fakeBackends.
|
|
// The returned slice is appended to as Spawn is called, so tests can
|
|
// inspect every backend that was minted.
|
|
func fakeSpawner(t *testing.T) (session.SpawnFunc, *[]*fakeBackend) {
|
|
t.Helper()
|
|
var (
|
|
mu sync.Mutex
|
|
backends []*fakeBackend
|
|
next int
|
|
)
|
|
spawn := func(ctx context.Context, sess *oauth.Session) (*session.Backend, error) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
fb := newFakeBackend(next)
|
|
next++
|
|
backends = append(backends, fb)
|
|
return fb.backend(), nil
|
|
}
|
|
return spawn, &backends
|
|
}
|
|
|
|
// testBearerHeader names the request header the test server reads to decide
// which oauth.Session to attach to a request's context. net/http cannot carry
// a client-side context value over to the server, so the pseudo bearer hash
// travels in this header instead. It stands in for the session RequireBearer
// would inject in production.
const testBearerHeader = "X-Test-Bearer-Hash"
|
|
|
|
// newTestServer wraps the Registry handler with a tiny middleware that
|
|
// reads testBearerHeader and attaches a matching oauth.Session to the
|
|
// request context. Without the header (i.e. simulating no auth), the
|
|
// registry sees a context without a session and returns 500 — exactly
|
|
// the production behaviour for unauthenticated /mcp traffic.
|
|
func newTestServer(t *testing.T, r *session.Registry) *httptest.Server {
|
|
t.Helper()
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
if hash := req.Header.Get(testBearerHeader); hash != "" {
|
|
ctx := oauth.ContextWithSession(req.Context(), bearerSess(hash))
|
|
req = req.WithContext(ctx)
|
|
}
|
|
r.Handler().ServeHTTP(w, req)
|
|
})
|
|
srv := httptest.NewServer(handler)
|
|
t.Cleanup(srv.Close)
|
|
return srv
|
|
}
|
|
|
|
// helper used at the package level so tests don't have to construct
|
|
// oauth.Session manually each time.
|
|
func bearerSess(hash string) *oauth.Session {
|
|
return &oauth.Session{
|
|
ClientID: "client-" + hash,
|
|
ForgejoUsername: "user-" + hash,
|
|
BrokerTokenHash: hash,
|
|
ForgejoToken: "fj-token-" + hash,
|
|
Scopes: "read:user",
|
|
}
|
|
}
|
|
|
|
func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) {
|
|
spawn, backends := fakeSpawner(t)
|
|
r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
|
|
srv := newTestServer(t, r)
|
|
|
|
resp := doReq(t, srv.URL, "", bearerSess("hash-A"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
|
|
}
|
|
sid := resp.Header.Get(session.SessionIDHeader)
|
|
if sid == "" {
|
|
t.Error("response missing Mcp-Session-Id header")
|
|
}
|
|
if r.Active() != 1 {
|
|
t.Errorf("Active() = %d, want 1", r.Active())
|
|
}
|
|
if len(*backends) != 1 || (*backends)[0].requests.Load() != 1 {
|
|
t.Errorf("backend was not invoked exactly once: %+v", *backends)
|
|
}
|
|
}
|
|
|
|
func TestServe_KnownSid_ReusesBackend(t *testing.T) {
|
|
spawn, backends := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
bearer := bearerSess("hash-B")
|
|
|
|
resp1 := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
resp1.Body.Close()
|
|
sid := resp1.Header.Get(session.SessionIDHeader)
|
|
|
|
resp2 := doReq(t, srv.URL, sid, bearer, `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`)
|
|
resp2.Body.Close()
|
|
if resp2.StatusCode != http.StatusOK {
|
|
t.Errorf("second request status = %d, want 200", resp2.StatusCode)
|
|
}
|
|
|
|
if r.Active() != 1 {
|
|
t.Errorf("Active() = %d, want 1 (reuse, not spawn)", r.Active())
|
|
}
|
|
if len(*backends) != 1 {
|
|
t.Errorf("Spawn called %d times, want 1", len(*backends))
|
|
}
|
|
if (*backends)[0].requests.Load() != 2 {
|
|
t.Errorf("backend.requests = %d, want 2", (*backends)[0].requests.Load())
|
|
}
|
|
}
|
|
|
|
func TestServe_UnknownSid_410(t *testing.T) {
|
|
spawn, _ := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
resp := doReq(t, srv.URL, "definitely-not-a-real-sid", bearerSess("hash-C"), `{}`)
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusGone {
|
|
t.Errorf("status = %d, want 410", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestServe_TokenMismatch_403(t *testing.T) {
|
|
// Two different bearer hashes for the same sid: only the original
|
|
// owner can access.
|
|
spawn, _ := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
first := doReq(t, srv.URL, "", bearerSess("alice"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
first.Body.Close()
|
|
sid := first.Header.Get(session.SessionIDHeader)
|
|
|
|
hijack := doReq(t, srv.URL, sid, bearerSess("eve"), `{"jsonrpc":"2.0","id":2}`)
|
|
hijack.Body.Close()
|
|
if hijack.StatusCode != http.StatusForbidden {
|
|
t.Errorf("status = %d, want 403", hijack.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestServe_NoAuthSessionInContext_500(t *testing.T) {
|
|
// Calling /mcp without going through RequireBearer first is a
|
|
// programmer error; the registry surfaces it loudly rather than
|
|
// silently spawning an unauthenticated session.
|
|
spawn, _ := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
resp, err := http.Post(srv.URL, "application/json", strings.NewReader(`{}`))
|
|
if err != nil {
|
|
t.Fatalf("post: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusInternalServerError {
|
|
t.Errorf("status = %d, want 500", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestServe_MaxSessionsCap(t *testing.T) {
|
|
spawn, _ := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{
|
|
Spawn: spawn, Log: brokerlog.Discard(), MaxSessions: 2,
|
|
})
|
|
srv := newTestServer(t, r)
|
|
|
|
// Two sessions allowed.
|
|
for i, hash := range []string{"a", "b"} {
|
|
resp := doReq(t, srv.URL, "", bearerSess(hash), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("session %d: status = %d", i, resp.StatusCode)
|
|
}
|
|
}
|
|
// Third hits the cap.
|
|
resp := doReq(t, srv.URL, "", bearerSess("c"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Errorf("third session status = %d, want 503", resp.StatusCode)
|
|
}
|
|
if resp.Header.Get("Retry-After") == "" {
|
|
t.Error("503 response should include Retry-After")
|
|
}
|
|
if r.Active() != 2 {
|
|
t.Errorf("Active() = %d, want 2", r.Active())
|
|
}
|
|
}
|
|
|
|
func TestServe_BackendDone_RemovesSession(t *testing.T) {
|
|
// When the child exits on its own (Done closes), the registry should
|
|
// reap the entry so subsequent traffic with that sid gets 410.
|
|
spawn, backends := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
bearer := bearerSess("crashed")
|
|
first := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
|
first.Body.Close()
|
|
sid := first.Header.Get(session.SessionIDHeader)
|
|
|
|
// Simulate the child exiting.
|
|
close((*backends)[0].done)
|
|
|
|
// Wait for the reaper goroutine — poll Active() rather than add a
|
|
// special hook to the production type.
|
|
if !waitForActive(r, 0, 2*time.Second) {
|
|
t.Fatalf("session count never dropped to 0")
|
|
}
|
|
|
|
resp := doReq(t, srv.URL, sid, bearer, `{}`)
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusGone {
|
|
t.Errorf("after backend.Done, status = %d, want 410", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestStop_TearsDownAllSessions(t *testing.T) {
|
|
spawn, backends := fakeSpawner(t)
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
|
srv := newTestServer(t, r)
|
|
|
|
for _, hash := range []string{"x", "y", "z"} {
|
|
resp := doReq(t, srv.URL, "", bearerSess(hash), `{}`)
|
|
resp.Body.Close()
|
|
}
|
|
if r.Active() != 3 {
|
|
t.Fatalf("Active = %d, want 3", r.Active())
|
|
}
|
|
|
|
r.Stop(context.Background())
|
|
if r.Active() != 0 {
|
|
t.Errorf("Active after Stop = %d, want 0", r.Active())
|
|
}
|
|
for _, b := range *backends {
|
|
if !b.stopped.Load() {
|
|
t.Errorf("backend %d not stopped", b.id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNew_RequiresSpawn(t *testing.T) {
|
|
_, err := session.New(session.Config{Log: brokerlog.Discard()})
|
|
if err == nil || !strings.Contains(err.Error(), "Spawn") {
|
|
t.Errorf("want Spawn-required error, got %v", err)
|
|
}
|
|
}
|
|
|
|
// doReq POSTs to the test server with an optional session id and a
|
|
// pseudo-bearer hash that the test middleware translates into an
|
|
// oauth.Session. Pass an empty hash to simulate an unauthenticated
|
|
// request.
|
|
func doReq(t *testing.T, url, sid string, sess *oauth.Session, body string) *http.Response {
|
|
t.Helper()
|
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, url, strings.NewReader(body))
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if sid != "" {
|
|
req.Header.Set(session.SessionIDHeader, sid)
|
|
}
|
|
if sess != nil {
|
|
req.Header.Set(testBearerHeader, sess.BrokerTokenHash)
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
return resp
|
|
}
|
|
|
|
// waitForActive polls r.Active() until it equals target or the deadline
|
|
// expires. Returns true on a hit. Used by the backend-Done test to give
|
|
// the registry's reaper goroutine time to run.
|
|
func waitForActive(r *session.Registry, target int, within time.Duration) bool {
|
|
deadline := time.Now().Add(within)
|
|
for time.Now().Before(deadline) {
|
|
if r.Active() == target {
|
|
return true
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
return false
|
|
}
|