forgejo-mcp-broker/internal/session/session.go
Ole-Morten Duesund 886092a600 feat(session): MCP session registry + spawn-on-initialize (forgejo-mcp-broker-t81)
Adds internal/session.Registry, the MCP session glue that maps
Mcp-Session-Id to a running forgejo-mcp child + bridge.

Lifecycle:
  - First /mcp POST without Mcp-Session-Id: SpawnFunc creates a backend
    (in production: supervisor.Start + bridge.New); registry mints a
    192-bit hex session id, attaches it to the response header, and
    dispatches the request to the new backend.
  - Subsequent POSTs with the header dispatch to the existing backend.
  - Unknown sids → 410 Gone (per MCP guidance, so clients re-initialise
    instead of retrying forever).
  - Sids are bound to the OAuth token that minted them: a different
    bearer probing a stolen sid gets 403, distinct from "your token is
    bad" (401) and "sid unknown" (410).

Cleanup:
  - When backend.Done closes (child exited on its own — crash, OOM,
    user-driven shutdown), a goroutine reaps the entry.
  - Stop tears every session down on broker shutdown. The 30s idle
    reaper and Forgejo token rotation come in 5c.

The Registry is decoupled from supervisor and bridge via SpawnFunc, so
tests don't need to fork real processes — they hand the registry a fake
that returns a controllable Backend. Also added oauth.ContextWithSession
so the session tests can inject an oauth.Session into request contexts
without standing up the full bearer-middleware chain.

Tests: 83.3% coverage. Cover spawn-on-initialize, sid reuse, unknown
sid, max-session cap with Retry-After, no-auth-context guard, sid
hijack defense (token mismatch → 403), Done-channel reaping, and
graceful Stop.

Closes forgejo-mcp-broker-t81.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 17:24:25 +02:00

235 lines
7.1 KiB
Go

// Package session is the broker's MCP session glue. It maps the
// `Mcp-Session-Id` header onto a running forgejo-mcp subprocess (managed
// by internal/supervisor) plus a bridge (internal/bridge) that pipes
// JSON-RPC traffic.
//
// One Registry handles all sessions. New session ids are minted on the
// first /mcp POST that arrives without a session header; subsequent
// requests with that header are dispatched to the same backend so a
// single user keeps the same forgejo-mcp child for the life of the
// session.
//
// The Registry knows when to spawn — it does not know how. Phase-5a tests
// inject fake SpawnFuncs to exercise the lifecycle without forking real
// processes. A production wiring lives in cmd/broker (phase 5c will
// finalise that).
package session
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
)
// Backend is the runtime view of one forgejo-mcp subprocess plus its
// bridge. The Registry calls Handler.ServeHTTP for each /mcp request
// belonging to the session, and calls Stop when the Registry itself is
// stopped or when a freshly spawned backend has to be rolled back.
type Backend struct {
// Handler serves every /mcp request dispatched to this session.
Handler http.Handler
// Stop shuts the backend down. The Registry logs a non-nil error from
// Registry.Stop but does not propagate it.
Stop func(ctx context.Context) error
// Done is watched by a per-session goroutine: once a receive on it
// succeeds (e.g. the channel is closed) the session entry is removed
// from the Registry. Presumably closed by the supervisor when the
// child exits — confirm against the production SpawnFunc.
Done <-chan struct{}
}
// SpawnFunc constructs a Backend for the supplied OAuth session. The
// production implementation spawns forgejo-mcp via supervisor and wires
// a bridge; tests pass fakes.
//
// The Registry invokes it with the context of the HTTP request that
// triggered the spawn. A non-nil error aborts session creation and is
// reported to the client as a spawn failure.
type SpawnFunc func(ctx context.Context, sess *oauth.Session) (*Backend, error)
// Config bundles the inputs to New. Only Spawn is mandatory; every other
// field has a usable default.
type Config struct {
// Spawn creates the backend for a new session. New rejects a nil Spawn.
Spawn SpawnFunc
// MaxSessions caps the number of tracked sessions. The cap is only
// enforced when the value is greater than zero.
MaxSessions int // 0 means unlimited
// Log receives spawn-failure and stop-failure messages.
Log *slog.Logger // optional; defaults to discard
// Now supplies the timestamps recorded as each session's last activity.
Now func() time.Time // optional; defaults to time.Now
}
// Registry tracks active sessions, dispatches requests to them, and tears
// them down on broker shutdown. Construct via New — the embedded sync.Map
// makes value copies unsafe.
type Registry struct {
// spawn creates the backend for each new session (Config.Spawn).
spawn SpawnFunc
// maxSessions is the session cap; values <= 0 disable the check.
maxSessions int
// log is never nil after New; it defaults to a discarding logger.
log *slog.Logger
// now is the clock used for last-activity timestamps.
now func() time.Time
// sessions is the live session table, keyed by session id.
sessions sync.Map // string sid → *entry
// count mirrors the number of entries in sessions so the cap check and
// Active do not need to walk the map.
count atomic.Int32
}
// entry is the Registry's record for one live session.
type entry struct {
// sid is the session id under which this entry is stored.
sid string
// backend is what requests for this session are dispatched to.
backend *Backend
// lastActive is refreshed on every request served for the session.
lastActive atomic.Int64 // unix nanoseconds; bumped per request
// oauthSess is the OAuth session that created this entry. Later
// requests must present the same BrokerTokenHash or they are refused.
oauthSess *oauth.Session
}
// SessionIDHeader is the streamable-HTTP MCP header that ferries the
// session id between client and server. The Registry reads it from each
// request and sets it on the response that creates a new session.
const SessionIDHeader = "Mcp-Session-Id"
// New builds a Registry from cfg. It fails only when cfg.Spawn is nil;
// a missing logger is replaced by one that discards everything and a
// missing clock by time.Now.
func New(cfg Config) (*Registry, error) {
	if cfg.Spawn == nil {
		return nil, errors.New("session: Spawn is required")
	}

	reg := &Registry{
		spawn:       cfg.Spawn,
		maxSessions: cfg.MaxSessions,
		log:         cfg.Log,
		now:         cfg.Now,
	}
	// Fill in defaults for the optional collaborators.
	if reg.log == nil {
		reg.log = slog.New(slog.DiscardHandler)
	}
	if reg.now == nil {
		reg.now = time.Now
	}
	return reg, nil
}
// Handler exposes the Registry as an http.Handler meant to be mounted at
// /mcp. It expects to sit behind oauth.Authenticator.RequireBearer so
// that each request's context already carries a validated Session.
func (r *Registry) Handler() http.Handler {
	var h http.HandlerFunc = r.serve
	return h
}
// serve routes one /mcp request to the backend that owns its session.
//
// Responses produced here (rather than by the backend):
//   - 500 when the request context carries no OAuth session,
//   - 410 when the supplied session id is not in the registry,
//   - 403 when the session id was minted under a different token,
//   - whatever respondSpawnError writes when a new session cannot be
//     created.
func (r *Registry) serve(w http.ResponseWriter, req *http.Request) {
	caller, authed := oauth.SessionFromContext(req.Context())
	if !authed {
		// Mounted without RequireBearer; treat as a programmer error.
		http.Error(w, "no auth session in context", http.StatusInternalServerError)
		return
	}

	var target *entry
	if sid := req.Header.Get(SessionIDHeader); sid != "" {
		// Existing session: unknown ids get 410 Gone so the client
		// re-initialises rather than retrying forever (per MCP guidance).
		found, known := r.sessions.Load(sid)
		if !known {
			http.Error(w, "unknown or expired session", http.StatusGone)
			return
		}
		target = found.(*entry)
		// The id is bound to the OAuth token that minted it. A different
		// bearer presenting it gets 403, which is distinguishable from a
		// bad token (401) and from an unknown id (410). Defence in depth.
		if target.oauthSess.BrokerTokenHash != caller.BrokerTokenHash {
			http.Error(w, "session bound to a different token", http.StatusForbidden)
			return
		}
	} else {
		// No session id yet — this must be an `initialize`. Mint a
		// session and advertise its id on the response.
		spawned, err := r.spawnSession(req.Context(), caller)
		if err != nil {
			r.respondSpawnError(w, err)
			return
		}
		w.Header().Set(SessionIDHeader, spawned.sid)
		target = spawned
	}

	// Common tail: record activity, then hand the request to the backend.
	target.lastActive.Store(r.now().UnixNano())
	target.backend.Handler.ServeHTTP(w, req)
}
// spawnSession creates a backend for oauthSess, registers it under a
// freshly minted session id, and starts a goroutine that removes the
// entry once a receive on the backend's Done channel succeeds.
//
// A slot is reserved in r.count before the SpawnFunc runs, so concurrent
// initialises cannot collectively exceed the session cap (a separate
// load-then-add would let several requests pass the check at once). The
// reservation is released on every failure path, including a panic. As a
// consequence Active briefly includes spawns that are still in flight.
//
// It returns errMaxSessions when the cap is reached, the SpawnFunc's
// error unchanged when spawning fails, and a plain error when the
// SpawnFunc yields a nil Backend or the minted id collides.
func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) (*entry, error) {
	// Reserve first, validate second: Add is atomic, so at most
	// maxSessions callers can ever observe a value within the cap.
	reserved := r.count.Add(1)
	committed := false
	defer func() {
		if !committed {
			r.count.Add(-1)
		}
	}()
	if r.maxSessions > 0 && int(reserved) > r.maxSessions {
		return nil, errMaxSessions
	}

	backend, err := r.spawn(ctx, oauthSess)
	if err != nil {
		return nil, err
	}
	if backend == nil {
		// Guard here: dereferencing a nil backend inside the watcher
		// goroutine below would panic outside net/http's recover and
		// take the whole broker down.
		return nil, errors.New("session: spawn returned nil backend")
	}

	sid := newSessionID()
	e := &entry{sid: sid, backend: backend, oauthSess: oauthSess}
	e.lastActive.Store(r.now().UnixNano())
	if _, loaded := r.sessions.LoadOrStore(sid, e); loaded {
		// Astronomically unlikely (24-byte random collision); roll back.
		if stopErr := backend.Stop(ctx); stopErr != nil {
			r.log.Warn("session rollback stop", slog.String("err", stopErr.Error()))
		}
		return nil, errors.New("session: id collision")
	}
	// The entry is now visible in the map; from here on the reserved slot
	// belongs to it and is released by whoever deletes the entry.
	committed = true

	// When the child exits on its own (crash, OOM, etc.), reap the entry.
	go func() {
		<-backend.Done
		r.removeSession(sid)
	}()
	return e, nil
}
// removeSession drops sid from the session table. The counter is only
// decremented when this call is the one that actually removed the entry,
// so calling it for an absent id is a no-op.
func (r *Registry) removeSession(sid string) {
	_, existed := r.sessions.LoadAndDelete(sid)
	if !existed {
		return
	}
	r.count.Add(-1)
}
// respondSpawnError translates a spawnSession failure into an HTTP
// response: hitting the session cap yields 503 with a Retry-After hint;
// anything else is logged and reported as a 500.
func (r *Registry) respondSpawnError(w http.ResponseWriter, err error) {
	if !errors.Is(err, errMaxSessions) {
		r.log.Error("session spawn failed", slog.String("err", err.Error()))
		http.Error(w, "session spawn failed", http.StatusInternalServerError)
		return
	}

	// Capacity problem: tell the client when it is worth trying again.
	w.Header().Set("Retry-After", "30")
	http.Error(w, "broker at max sessions", http.StatusServiceUnavailable)
}
// Active reports how many sessions the Registry is tracking right now,
// as read from its atomic counter. Intended mainly for tests and metrics.
func (r *Registry) Active() int {
	tracked := r.count.Load()
	return int(tracked)
}
// Stop tears down every active session. Used at broker shutdown so
// children don't leak past the parent. Best-effort: stop errors are
// logged but not aggregated.
//
// ctx is handed to each backend's Stop function unchanged.
func (r *Registry) Stop(ctx context.Context) {
	r.sessions.Range(func(k, v any) bool {
		e := v.(*entry)
		if err := e.backend.Stop(ctx); err != nil {
			r.log.Warn("session stop", slog.String("sid", e.sid), slog.String("err", err.Error()))
		}
		// Stopping a backend may close its Done channel, in which case
		// the per-session watcher goroutine races us to remove the same
		// entry. Only the side that actually deletes it may touch the
		// counter; an unconditional decrement here would count the
		// session twice and drive the counter negative.
		if _, deleted := r.sessions.LoadAndDelete(k); deleted {
			r.count.Add(-1)
		}
		return true
	})
}
// LastActive reports when sid last served a request, reconstructed from
// the nanosecond timestamp stored on its entry. An unknown sid yields the
// zero time. Phase 5c's reaper uses this to evict idle sessions.
func (r *Registry) LastActive(sid string) time.Time {
	found, known := r.sessions.Load(sid)
	if !known {
		var never time.Time
		return never
	}
	nanos := found.(*entry).lastActive.Load()
	return time.Unix(0, nanos)
}
// errMaxSessions signals the cap was hit. Internal sentinel only —
// callers see a 503. It is returned by spawnSession and matched with
// errors.Is in respondSpawnError.
var errMaxSessions = errors.New("session: max sessions reached")
// newSessionID mints a fresh session id: 24 bytes from crypto/rand,
// rendered as 48 lowercase hex characters (192 bits of entropy). It
// panics if the system's random source reports an error.
func newSessionID() string {
	const idBytes = 24

	var raw [idBytes]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic("session: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(raw[:])
}