345 lines
10 KiB
Go
345 lines
10 KiB
Go
|
|
package session_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"sync/atomic"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/session"
|
||
|
|
)
|
||
|
|
|
||
|
|
// fakeBackend is a controllable session.Backend used by all tests in this
|
||
|
|
// package. It records every Handler invocation and exposes a Done channel
|
||
|
|
// the tests can close to simulate the child exiting.
|
||
|
|
type fakeBackend struct {
|
||
|
|
id int
|
||
|
|
handler func(w http.ResponseWriter, r *http.Request)
|
||
|
|
done chan struct{}
|
||
|
|
stopErr error
|
||
|
|
stopped atomic.Bool
|
||
|
|
requests atomic.Int32
|
||
|
|
}
|
||
|
|
|
||
|
|
func newFakeBackend(id int) *fakeBackend {
|
||
|
|
return &fakeBackend{id: id, done: make(chan struct{})}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (f *fakeBackend) backend() *session.Backend {
|
||
|
|
return &session.Backend{
|
||
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
f.requests.Add(1)
|
||
|
|
if f.handler != nil {
|
||
|
|
f.handler(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
fmt.Fprintf(w, "backend-%d-served", f.id)
|
||
|
|
}),
|
||
|
|
Stop: func(ctx context.Context) error {
|
||
|
|
f.stopped.Store(true)
|
||
|
|
select {
|
||
|
|
case <-f.done:
|
||
|
|
default:
|
||
|
|
close(f.done)
|
||
|
|
}
|
||
|
|
return f.stopErr
|
||
|
|
},
|
||
|
|
Done: f.done,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// fakeSpawner returns a SpawnFunc that hands out a sequence of fakeBackends.
|
||
|
|
// The returned slice is appended to as Spawn is called, so tests can
|
||
|
|
// inspect every backend that was minted.
|
||
|
|
func fakeSpawner(t *testing.T) (session.SpawnFunc, *[]*fakeBackend) {
|
||
|
|
t.Helper()
|
||
|
|
var (
|
||
|
|
mu sync.Mutex
|
||
|
|
backends []*fakeBackend
|
||
|
|
next int
|
||
|
|
)
|
||
|
|
spawn := func(ctx context.Context, sess *oauth.Session) (*session.Backend, error) {
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
fb := newFakeBackend(next)
|
||
|
|
next++
|
||
|
|
backends = append(backends, fb)
|
||
|
|
return fb.backend(), nil
|
||
|
|
}
|
||
|
|
return spawn, &backends
|
||
|
|
}
|
||
|
|
|
||
|
|
// testBearerHeader is a test-only request header that carries the
// bearer-token hash from client to server. net/http offers no way to ship a
// context.Context across the wire, so the test middleware reads this header
// and attaches the matching oauth.Session itself. This stands in for what
// RequireBearer injects in production.
const (
	testBearerHeader = "X-Test-Bearer-Hash"
)
|
||
|
|
|
||
|
|
// newTestServer wraps the Registry handler with a tiny middleware that
|
||
|
|
// reads testBearerHeader and attaches a matching oauth.Session to the
|
||
|
|
// request context. Without the header (i.e. simulating no auth), the
|
||
|
|
// registry sees a context without a session and returns 500 — exactly
|
||
|
|
// the production behaviour for unauthenticated /mcp traffic.
|
||
|
|
func newTestServer(t *testing.T, r *session.Registry) *httptest.Server {
|
||
|
|
t.Helper()
|
||
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
|
if hash := req.Header.Get(testBearerHeader); hash != "" {
|
||
|
|
ctx := oauth.ContextWithSession(req.Context(), bearerSess(hash))
|
||
|
|
req = req.WithContext(ctx)
|
||
|
|
}
|
||
|
|
r.Handler().ServeHTTP(w, req)
|
||
|
|
})
|
||
|
|
srv := httptest.NewServer(handler)
|
||
|
|
t.Cleanup(srv.Close)
|
||
|
|
return srv
|
||
|
|
}
|
||
|
|
|
||
|
|
// helper used at the package level so tests don't have to construct
|
||
|
|
// oauth.Session manually each time.
|
||
|
|
func bearerSess(hash string) *oauth.Session {
|
||
|
|
return &oauth.Session{
|
||
|
|
ClientID: "client-" + hash,
|
||
|
|
ForgejoUsername: "user-" + hash,
|
||
|
|
BrokerTokenHash: hash,
|
||
|
|
ForgejoToken: "fj-token-" + hash,
|
||
|
|
Scopes: "read:user",
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_NewSession_MintsSidAndDispatches(t *testing.T) {
|
||
|
|
spawn, backends := fakeSpawner(t)
|
||
|
|
r, err := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("New: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
resp := doReq(t, srv.URL, "", bearerSess("hash-A"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
body, _ := io.ReadAll(resp.Body)
|
||
|
|
t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
|
||
|
|
}
|
||
|
|
sid := resp.Header.Get(session.SessionIDHeader)
|
||
|
|
if sid == "" {
|
||
|
|
t.Error("response missing Mcp-Session-Id header")
|
||
|
|
}
|
||
|
|
if r.Active() != 1 {
|
||
|
|
t.Errorf("Active() = %d, want 1", r.Active())
|
||
|
|
}
|
||
|
|
if len(*backends) != 1 || (*backends)[0].requests.Load() != 1 {
|
||
|
|
t.Errorf("backend was not invoked exactly once: %+v", *backends)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_KnownSid_ReusesBackend(t *testing.T) {
|
||
|
|
spawn, backends := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
bearer := bearerSess("hash-B")
|
||
|
|
|
||
|
|
resp1 := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
resp1.Body.Close()
|
||
|
|
sid := resp1.Header.Get(session.SessionIDHeader)
|
||
|
|
|
||
|
|
resp2 := doReq(t, srv.URL, sid, bearer, `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`)
|
||
|
|
resp2.Body.Close()
|
||
|
|
if resp2.StatusCode != http.StatusOK {
|
||
|
|
t.Errorf("second request status = %d, want 200", resp2.StatusCode)
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.Active() != 1 {
|
||
|
|
t.Errorf("Active() = %d, want 1 (reuse, not spawn)", r.Active())
|
||
|
|
}
|
||
|
|
if len(*backends) != 1 {
|
||
|
|
t.Errorf("Spawn called %d times, want 1", len(*backends))
|
||
|
|
}
|
||
|
|
if (*backends)[0].requests.Load() != 2 {
|
||
|
|
t.Errorf("backend.requests = %d, want 2", (*backends)[0].requests.Load())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_UnknownSid_410(t *testing.T) {
|
||
|
|
spawn, _ := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
resp := doReq(t, srv.URL, "definitely-not-a-real-sid", bearerSess("hash-C"), `{}`)
|
||
|
|
resp.Body.Close()
|
||
|
|
if resp.StatusCode != http.StatusGone {
|
||
|
|
t.Errorf("status = %d, want 410", resp.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_TokenMismatch_403(t *testing.T) {
|
||
|
|
// Two different bearer hashes for the same sid: only the original
|
||
|
|
// owner can access.
|
||
|
|
spawn, _ := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
first := doReq(t, srv.URL, "", bearerSess("alice"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
first.Body.Close()
|
||
|
|
sid := first.Header.Get(session.SessionIDHeader)
|
||
|
|
|
||
|
|
hijack := doReq(t, srv.URL, sid, bearerSess("eve"), `{"jsonrpc":"2.0","id":2}`)
|
||
|
|
hijack.Body.Close()
|
||
|
|
if hijack.StatusCode != http.StatusForbidden {
|
||
|
|
t.Errorf("status = %d, want 403", hijack.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_NoAuthSessionInContext_500(t *testing.T) {
|
||
|
|
// Calling /mcp without going through RequireBearer first is a
|
||
|
|
// programmer error; the registry surfaces it loudly rather than
|
||
|
|
// silently spawning an unauthenticated session.
|
||
|
|
spawn, _ := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
resp, err := http.Post(srv.URL, "application/json", strings.NewReader(`{}`))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("post: %v", err)
|
||
|
|
}
|
||
|
|
resp.Body.Close()
|
||
|
|
if resp.StatusCode != http.StatusInternalServerError {
|
||
|
|
t.Errorf("status = %d, want 500", resp.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_MaxSessionsCap(t *testing.T) {
|
||
|
|
spawn, _ := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{
|
||
|
|
Spawn: spawn, Log: brokerlog.Discard(), MaxSessions: 2,
|
||
|
|
})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
// Two sessions allowed.
|
||
|
|
for i, hash := range []string{"a", "b"} {
|
||
|
|
resp := doReq(t, srv.URL, "", bearerSess(hash), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
resp.Body.Close()
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
t.Fatalf("session %d: status = %d", i, resp.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
// Third hits the cap.
|
||
|
|
resp := doReq(t, srv.URL, "", bearerSess("c"), `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
resp.Body.Close()
|
||
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||
|
|
t.Errorf("third session status = %d, want 503", resp.StatusCode)
|
||
|
|
}
|
||
|
|
if resp.Header.Get("Retry-After") == "" {
|
||
|
|
t.Error("503 response should include Retry-After")
|
||
|
|
}
|
||
|
|
if r.Active() != 2 {
|
||
|
|
t.Errorf("Active() = %d, want 2", r.Active())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServe_BackendDone_RemovesSession(t *testing.T) {
|
||
|
|
// When the child exits on its own (Done closes), the registry should
|
||
|
|
// reap the entry so subsequent traffic with that sid gets 410.
|
||
|
|
spawn, backends := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
bearer := bearerSess("crashed")
|
||
|
|
first := doReq(t, srv.URL, "", bearer, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
|
||
|
|
first.Body.Close()
|
||
|
|
sid := first.Header.Get(session.SessionIDHeader)
|
||
|
|
|
||
|
|
// Simulate the child exiting.
|
||
|
|
close((*backends)[0].done)
|
||
|
|
|
||
|
|
// Wait for the reaper goroutine — poll Active() rather than add a
|
||
|
|
// special hook to the production type.
|
||
|
|
if !waitForActive(r, 0, 2*time.Second) {
|
||
|
|
t.Fatalf("session count never dropped to 0")
|
||
|
|
}
|
||
|
|
|
||
|
|
resp := doReq(t, srv.URL, sid, bearer, `{}`)
|
||
|
|
resp.Body.Close()
|
||
|
|
if resp.StatusCode != http.StatusGone {
|
||
|
|
t.Errorf("after backend.Done, status = %d, want 410", resp.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStop_TearsDownAllSessions(t *testing.T) {
|
||
|
|
spawn, backends := fakeSpawner(t)
|
||
|
|
r, _ := session.New(session.Config{Spawn: spawn, Log: brokerlog.Discard()})
|
||
|
|
srv := newTestServer(t, r)
|
||
|
|
|
||
|
|
for _, hash := range []string{"x", "y", "z"} {
|
||
|
|
resp := doReq(t, srv.URL, "", bearerSess(hash), `{}`)
|
||
|
|
resp.Body.Close()
|
||
|
|
}
|
||
|
|
if r.Active() != 3 {
|
||
|
|
t.Fatalf("Active = %d, want 3", r.Active())
|
||
|
|
}
|
||
|
|
|
||
|
|
r.Stop(context.Background())
|
||
|
|
if r.Active() != 0 {
|
||
|
|
t.Errorf("Active after Stop = %d, want 0", r.Active())
|
||
|
|
}
|
||
|
|
for _, b := range *backends {
|
||
|
|
if !b.stopped.Load() {
|
||
|
|
t.Errorf("backend %d not stopped", b.id)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestNew_RequiresSpawn(t *testing.T) {
|
||
|
|
_, err := session.New(session.Config{Log: brokerlog.Discard()})
|
||
|
|
if err == nil || !strings.Contains(err.Error(), "Spawn") {
|
||
|
|
t.Errorf("want Spawn-required error, got %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// doReq POSTs to the test server with an optional session id and a
|
||
|
|
// pseudo-bearer hash that the test middleware translates into an
|
||
|
|
// oauth.Session. Pass an empty hash to simulate an unauthenticated
|
||
|
|
// request.
|
||
|
|
func doReq(t *testing.T, url, sid string, sess *oauth.Session, body string) *http.Response {
|
||
|
|
t.Helper()
|
||
|
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, url, strings.NewReader(body))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("new request: %v", err)
|
||
|
|
}
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
if sid != "" {
|
||
|
|
req.Header.Set(session.SessionIDHeader, sid)
|
||
|
|
}
|
||
|
|
if sess != nil {
|
||
|
|
req.Header.Set(testBearerHeader, sess.BrokerTokenHash)
|
||
|
|
}
|
||
|
|
resp, err := http.DefaultClient.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("do: %v", err)
|
||
|
|
}
|
||
|
|
return resp
|
||
|
|
}
|
||
|
|
|
||
|
|
// waitForActive polls r.Active() until it equals target or the deadline
|
||
|
|
// expires. Returns true on a hit. Used by the backend-Done test to give
|
||
|
|
// the registry's reaper goroutine time to run.
|
||
|
|
func waitForActive(r *session.Registry, target int, within time.Duration) bool {
|
||
|
|
deadline := time.Now().Add(within)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
if r.Active() == target {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
time.Sleep(5 * time.Millisecond)
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|