// Package session is the broker's MCP session glue. It maps the
// `Mcp-Session-Id` header onto a running forgejo-mcp subprocess (managed
// by internal/supervisor) plus a bridge (internal/bridge) that pipes
// JSON-RPC traffic.
//
// One Registry handles all sessions. New session ids are minted on the
// first /mcp request that arrives without a session header (normally the
// `initialize` POST); subsequent requests with that header are dispatched
// to the same backend so a single user keeps the same forgejo-mcp child
// for the life of the session.
//
// The Registry knows when to spawn — it does not know how. Phase-5a tests
// inject fake SpawnFuncs to exercise the lifecycle without forking real
// processes. A production wiring lives in cmd/broker (phase 5c will
// finalise that).
package session

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"log/slog"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
)

// Backend is the runtime view of one forgejo-mcp subprocess plus its
// bridge. The Registry calls Handler.ServeHTTP for each /mcp request
// belonging to the session and Stop when the session is reaped.
//
// The Registry waits on Done and drops the session once it closes, but
// only while this Backend is still the session's current one. A nil Done
// never becomes ready, so such a session is never dropped on exit.
type Backend struct {
	// Handler serves every /mcp request routed to this session.
	Handler http.Handler
	// Stop tears the backend down. Registry.Stop calls it at broker
	// shutdown, and session creation calls it to roll back on an id
	// collision.
	Stop func(ctx context.Context) error
	// Done is expected to be closed when the backend exits.
	Done <-chan struct{}
}

// SpawnFunc constructs a Backend for the supplied OAuth session. The
// production implementation spawns forgejo-mcp via supervisor and wires
// a bridge; tests pass fakes.
//
// ctx is the context of the HTTP request that triggered the spawn.
// NOTE(review): if an implementation ties the child's lifetime to ctx,
// the child will not outlive that first request — confirm.
type SpawnFunc func(ctx context.Context, sess *oauth.Session) (*Backend, error)

// Config bundles the inputs to New.
type Config struct {
	Spawn       SpawnFunc        // required; New rejects a nil Spawn
	MaxSessions int              // 0 (or negative) means unlimited
	Log         *slog.Logger     // optional; defaults to discard
	Now         func() time.Time // optional; defaults to time.Now
}

// Registry tracks active sessions, dispatches requests to them, and tears
// them down on broker shutdown. Construct via New — its sync.Map and
// atomic fields make value copies unsafe, and the zero value has no
// SpawnFunc.
type Registry struct { spawn SpawnFunc maxSessions int log *slog.Logger now func() time.Time sessions sync.Map // string sid → *entry count atomic.Int32 } type entry struct { sid string backend atomic.Pointer[Backend] // swapped on rotation; readers use Load lastActive atomic.Int64 // unix nanoseconds; bumped per request mu sync.Mutex // guards oauthSess; backend swap holds this too oauthSess *oauth.Session } // SessionIDHeader is the streamable-HTTP MCP header that ferries the // session id between client and server. const SessionIDHeader = "Mcp-Session-Id" // New validates the config and returns a Registry ready to serve. func New(cfg Config) (*Registry, error) { if cfg.Spawn == nil { return nil, errors.New("session: Spawn is required") } log := cfg.Log if log == nil { log = slog.New(slog.DiscardHandler) } now := cfg.Now if now == nil { now = time.Now } return &Registry{ spawn: cfg.Spawn, maxSessions: cfg.MaxSessions, log: log, now: now, }, nil } // Handler returns an http.Handler suitable for mounting at /mcp. Wrap it // with oauth.Authenticator.RequireBearer so every request carries a // validated Session in its context. func (r *Registry) Handler() http.Handler { return http.HandlerFunc(r.serve) } func (r *Registry) serve(w http.ResponseWriter, req *http.Request) { oauthSess, ok := oauth.SessionFromContext(req.Context()) if !ok { // Mounted without RequireBearer; treat as a programmer error. http.Error(w, "no auth session in context", http.StatusInternalServerError) return } sid := req.Header.Get(SessionIDHeader) if sid == "" { // No session id yet — this must be an `initialize`. Mint a session. e, err := r.spawnSession(req.Context(), oauthSess) if err != nil { r.respondSpawnError(w, err) return } w.Header().Set(SessionIDHeader, e.sid) e.lastActive.Store(r.now().UnixNano()) e.backend.Load().Handler.ServeHTTP(w, req) return } // Lookup; reject unknown session ids with 410 Gone so the client // re-initialises rather than retrying forever (per MCP guidance). 
v, ok := r.sessions.Load(sid) if !ok { http.Error(w, "unknown or expired session", http.StatusGone) return } e := v.(*entry) if e.snapshotOAuth().BrokerTokenHash != oauthSess.BrokerTokenHash { // Session id is bound to the OAuth token that minted it. A // different bearer probing a stolen sid gets 403 — not 401, so // this is distinct from "your token is bad" and from "we don't // know that sid" (410 above). Defence in depth. http.Error(w, "session bound to a different token", http.StatusForbidden) return } e.lastActive.Store(r.now().UnixNano()) e.backend.Load().Handler.ServeHTTP(w, req) } // snapshotOAuth returns a pointer to the entry's current oauthSess under // lock so callers don't see partial swaps during rotation. func (e *entry) snapshotOAuth() *oauth.Session { e.mu.Lock() defer e.mu.Unlock() return e.oauthSess } func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) (*entry, error) { if r.maxSessions > 0 && int(r.count.Load()) >= r.maxSessions { return nil, errMaxSessions } backend, err := r.spawn(ctx, oauthSess) if err != nil { return nil, err } sid := newSessionID() e := &entry{sid: sid, oauthSess: oauthSess} e.backend.Store(backend) e.lastActive.Store(r.now().UnixNano()) if _, loaded := r.sessions.LoadOrStore(sid, e); loaded { // Astronomically unlikely (24-byte random collision); roll back. _ = backend.Stop(ctx) return nil, errors.New("session: id collision") } r.count.Add(1) r.watchBackend(sid, backend) return e, nil } // watchBackend launches a goroutine that removes the session if the given // backend's Done closes WHILE that backend is still the entry's current // one. After a rotation, the old backend's Done eventually closes too, // but the entry now points at a new backend; in that case the watcher // is a no-op so the session survives the rotation. 
func (r *Registry) watchBackend(sid string, backend *Backend) { go func() { <-backend.Done v, ok := r.sessions.Load(sid) if !ok { return } e := v.(*entry) if e.backend.Load() == backend { r.removeSession(sid) } }() } func (r *Registry) removeSession(sid string) { if _, ok := r.sessions.LoadAndDelete(sid); ok { r.count.Add(-1) } } func (r *Registry) respondSpawnError(w http.ResponseWriter, err error) { if errors.Is(err, errMaxSessions) { w.Header().Set("Retry-After", "30") http.Error(w, "broker at max sessions", http.StatusServiceUnavailable) return } r.log.Error("session spawn failed", slog.String("err", err.Error())) http.Error(w, "session spawn failed", http.StatusInternalServerError) } // Active returns the number of currently-tracked sessions. Mostly for // tests and metrics. func (r *Registry) Active() int { return int(r.count.Load()) } // Stop tears down every active session. Used at broker shutdown so // children don't leak past the parent. Best-effort: stop errors are // logged but not aggregated. func (r *Registry) Stop(ctx context.Context) { r.sessions.Range(func(k, v any) bool { e := v.(*entry) if err := e.backend.Load().Stop(ctx); err != nil { r.log.Warn("session stop", slog.String("sid", e.sid), slog.String("err", err.Error())) } r.sessions.Delete(k) r.count.Add(-1) return true }) } // LastActive returns the last-active wall-clock time for sid, or the // zero time if no such session exists. Phase 5c's reaper uses this to // evict idle sessions. func (r *Registry) LastActive(sid string) time.Time { v, ok := r.sessions.Load(sid) if !ok { return time.Time{} } return time.Unix(0, v.(*entry).lastActive.Load()) } // errMaxSessions signals the cap was hit. Internal sentinel only — // callers see a 503. var errMaxSessions = errors.New("session: max sessions reached") // newSessionID returns a hex-encoded 24-byte random id (48 hex chars, // 192 bits of entropy). Plenty for global uniqueness across a broker. 
func newSessionID() string {
	// Fixed-size array: the id length is a compile-time constant.
	var raw [24]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		// Without a working CSPRNG we cannot mint unguessable session ids,
		// so there is no safe fallback.
		panic("session: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(raw[:])
}