Adds StartReaper to internal/session — two background goroutines that
keep the session map healthy under steady load.
Idle reaper:
- Sweeps every ReapInterval (default 30s) for sessions whose
LastActive is older than IdleTimeout (default 15m).
- Evicts via SIGTERM through the Backend.Stop hook.
Token rotator:
- Sweeps every RotateInterval (default 1m) for sessions whose Forgejo
token is within RefreshLead (default 5m) of expiry.
- Calls the operator-supplied RefreshForgejo to obtain new
access+refresh tokens, then Respawn to mint a new Backend with the
updated token in env.
- Atomically swaps e.backend (now an atomic.Pointer[Backend]); the
sid is preserved so the client just re-issues an MCP `initialize`
on its next request rather than re-authenticating.
- On refresh failure, evicts so the next /mcp produces a clean
re-auth instead of carrying a stale token.
Two race fixes uncovered by -race during this work:
- The Done-watcher started in spawnSession captured the original
backend pointer; after rotation it still saw Done close (because
the old backend was Stopped) and would yank the entire entry. Fixed
by comparing watched-backend == e.backend.Load() before evicting.
- The fakeSpawner test helper let tests read the backends slice
without the lock the spawn callback held. Replaced with a
spawnerControl type whose count/at/snapshot methods all lock.
Tests cover idle eviction, recently-active sessions surviving sweeps,
successful rotation+respawn (sid preserved), refresh failure → eviction,
and Stop idempotency.
Closes forgejo-mcp-broker-q4x. Phase 5 complete.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
261 lines
8 KiB
Go
261 lines
8 KiB
Go
// Package session is the broker's MCP session glue. It maps the
|
|
// `Mcp-Session-Id` header onto a running forgejo-mcp subprocess (managed
|
|
// by internal/supervisor) plus a bridge (internal/bridge) that pipes
|
|
// JSON-RPC traffic.
|
|
//
|
|
// One Registry handles all sessions. New session ids are minted on the
|
|
// first /mcp POST that arrives without a session header; subsequent
|
|
// requests with that header are dispatched to the same backend so a
|
|
// single user keeps the same forgejo-mcp child for the life of the
|
|
// session.
|
|
//
|
|
// The Registry knows when to spawn — it does not know how. Phase-5a tests
|
|
// inject fake SpawnFuncs to exercise the lifecycle without forking real
|
|
// processes. The production wiring lives in cmd/broker (phase 5c will
|
|
// finalise that).
|
|
package session
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
|
)
|
|
|
|
// Backend is the runtime view of one forgejo-mcp subprocess plus its
// bridge. The Registry calls Handler.ServeHTTP for each /mcp request
// belonging to the session and Stop when the session is reaped.
type Backend struct {
	// Handler serves every /mcp request routed to this session, including
	// the request that caused the session to be minted.
	Handler http.Handler

	// Stop shuts the backend down. Within the Registry it is called by
	// Registry.Stop at broker shutdown (errors are logged) and to roll back
	// a freshly spawned backend whose sid collided (errors are discarded).
	Stop func(ctx context.Context) error

	// Done closing signals that the backend has gone away. If this backend
	// is still the session's current one at that moment, the session is
	// removed from the Registry; after a rotation the old backend's Done is
	// ignored. A nil Done can never be received from, so a session whose
	// backend has a nil Done is never removed this way.
	Done <-chan struct{}
}
|
|
|
|
// SpawnFunc constructs a Backend for the supplied OAuth session. The
// production implementation spawns forgejo-mcp via supervisor and wires
// a bridge; tests pass fakes.
//
// When err is nil the returned Backend must be non-nil: the Registry
// dereferences it immediately to serve the minting request. A non-nil err
// aborts session creation and the client receives a 500.
//
// The Registry passes the context of the HTTP request that is minting the
// session. NOTE(review): if an implementation ties the child's lifetime to
// ctx (e.g. exec.CommandContext), the child would be killed once that first
// request completes — confirm the production SpawnFunc detaches from ctx.
type SpawnFunc func(ctx context.Context, sess *oauth.Session) (*Backend, error)
|
|
|
|
// Config bundles the inputs to New.
type Config struct {
	// Spawn builds the Backend for each new session. Required: New
	// returns an error when it is nil.
	Spawn SpawnFunc

	// MaxSessions caps the number of concurrent sessions. 0 means
	// unlimited; so does any negative value, since only positive caps are
	// enforced.
	MaxSessions int

	// Log is optional; when nil, New installs a logger that discards
	// everything.
	Log *slog.Logger

	// Now is optional; defaults to time.Now. It supplies the timestamps
	// recorded as a session's last activity.
	Now func() time.Time
}
|
|
|
|
// Registry tracks active sessions, dispatches requests to them, and tears
// them down on broker shutdown. Construct via New: a zero Registry has no
// SpawnFunc and no logger, and its sync.Map and atomic.Int32 fields make
// copying a Registry value unsafe.
type Registry struct {
	spawn       SpawnFunc
	maxSessions int // <= 0 means unlimited
	log         *slog.Logger
	now         func() time.Time

	sessions sync.Map     // string sid → *entry
	count    atomic.Int32 // session counter reported by Active
}
|
|
|
|
// entry is the Registry's record for a single session.
type entry struct {
	sid string // the Mcp-Session-Id value; also this entry's key in Registry.sessions

	backend    atomic.Pointer[Backend] // swapped on rotation; readers use Load
	lastActive atomic.Int64            // unix nanoseconds; bumped per request

	mu sync.Mutex // guards oauthSess; backend swap holds this too
	// oauthSess is the OAuth session the sid was minted under; requests
	// bearing a token with a different BrokerTokenHash are refused.
	oauthSess *oauth.Session
}
|
|
|
|
// SessionIDHeader is the streamable-HTTP MCP header that ferries the
// session id between client and server. The Registry sets it on the
// response that mints a session and reads it from every later request.
const SessionIDHeader = "Mcp-Session-Id"
|
|
|
|
// New validates the config and returns a Registry ready to serve.
|
|
func New(cfg Config) (*Registry, error) {
|
|
if cfg.Spawn == nil {
|
|
return nil, errors.New("session: Spawn is required")
|
|
}
|
|
log := cfg.Log
|
|
if log == nil {
|
|
log = slog.New(slog.DiscardHandler)
|
|
}
|
|
now := cfg.Now
|
|
if now == nil {
|
|
now = time.Now
|
|
}
|
|
return &Registry{
|
|
spawn: cfg.Spawn,
|
|
maxSessions: cfg.MaxSessions,
|
|
log: log,
|
|
now: now,
|
|
}, nil
|
|
}
|
|
|
|
// Handler returns an http.Handler suitable for mounting at /mcp. Wrap it
|
|
// with oauth.Authenticator.RequireBearer so every request carries a
|
|
// validated Session in its context.
|
|
func (r *Registry) Handler() http.Handler {
|
|
return http.HandlerFunc(r.serve)
|
|
}
|
|
|
|
func (r *Registry) serve(w http.ResponseWriter, req *http.Request) {
|
|
oauthSess, ok := oauth.SessionFromContext(req.Context())
|
|
if !ok {
|
|
// Mounted without RequireBearer; treat as a programmer error.
|
|
http.Error(w, "no auth session in context", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
sid := req.Header.Get(SessionIDHeader)
|
|
if sid == "" {
|
|
// No session id yet — this must be an `initialize`. Mint a session.
|
|
e, err := r.spawnSession(req.Context(), oauthSess)
|
|
if err != nil {
|
|
r.respondSpawnError(w, err)
|
|
return
|
|
}
|
|
w.Header().Set(SessionIDHeader, e.sid)
|
|
e.lastActive.Store(r.now().UnixNano())
|
|
e.backend.Load().Handler.ServeHTTP(w, req)
|
|
return
|
|
}
|
|
|
|
// Lookup; reject unknown session ids with 410 Gone so the client
|
|
// re-initialises rather than retrying forever (per MCP guidance).
|
|
v, ok := r.sessions.Load(sid)
|
|
if !ok {
|
|
http.Error(w, "unknown or expired session", http.StatusGone)
|
|
return
|
|
}
|
|
e := v.(*entry)
|
|
if e.snapshotOAuth().BrokerTokenHash != oauthSess.BrokerTokenHash {
|
|
// Session id is bound to the OAuth token that minted it. A
|
|
// different bearer probing a stolen sid gets 403 — not 401, so
|
|
// this is distinct from "your token is bad" and from "we don't
|
|
// know that sid" (410 above). Defence in depth.
|
|
http.Error(w, "session bound to a different token", http.StatusForbidden)
|
|
return
|
|
}
|
|
e.lastActive.Store(r.now().UnixNano())
|
|
e.backend.Load().Handler.ServeHTTP(w, req)
|
|
}
|
|
|
|
// snapshotOAuth returns a pointer to the entry's current oauthSess under
|
|
// lock so callers don't see partial swaps during rotation.
|
|
func (e *entry) snapshotOAuth() *oauth.Session {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
return e.oauthSess
|
|
}
|
|
|
|
func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) (*entry, error) {
|
|
if r.maxSessions > 0 && int(r.count.Load()) >= r.maxSessions {
|
|
return nil, errMaxSessions
|
|
}
|
|
|
|
backend, err := r.spawn(ctx, oauthSess)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sid := newSessionID()
|
|
e := &entry{sid: sid, oauthSess: oauthSess}
|
|
e.backend.Store(backend)
|
|
e.lastActive.Store(r.now().UnixNano())
|
|
|
|
if _, loaded := r.sessions.LoadOrStore(sid, e); loaded {
|
|
// Astronomically unlikely (24-byte random collision); roll back.
|
|
_ = backend.Stop(ctx)
|
|
return nil, errors.New("session: id collision")
|
|
}
|
|
r.count.Add(1)
|
|
|
|
r.watchBackend(sid, backend)
|
|
return e, nil
|
|
}
|
|
|
|
// watchBackend launches a goroutine that removes the session if the given
|
|
// backend's Done closes WHILE that backend is still the entry's current
|
|
// one. After a rotation, the old backend's Done eventually closes too,
|
|
// but the entry now points at a new backend; in that case the watcher
|
|
// is a no-op so the session survives the rotation.
|
|
func (r *Registry) watchBackend(sid string, backend *Backend) {
|
|
go func() {
|
|
<-backend.Done
|
|
v, ok := r.sessions.Load(sid)
|
|
if !ok {
|
|
return
|
|
}
|
|
e := v.(*entry)
|
|
if e.backend.Load() == backend {
|
|
r.removeSession(sid)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (r *Registry) removeSession(sid string) {
|
|
if _, ok := r.sessions.LoadAndDelete(sid); ok {
|
|
r.count.Add(-1)
|
|
}
|
|
}
|
|
|
|
func (r *Registry) respondSpawnError(w http.ResponseWriter, err error) {
|
|
if errors.Is(err, errMaxSessions) {
|
|
w.Header().Set("Retry-After", "30")
|
|
http.Error(w, "broker at max sessions", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
r.log.Error("session spawn failed", slog.String("err", err.Error()))
|
|
http.Error(w, "session spawn failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
// Active returns the number of currently-tracked sessions. Mostly for
|
|
// tests and metrics.
|
|
func (r *Registry) Active() int { return int(r.count.Load()) }
|
|
|
|
// Stop tears down every active session. Used at broker shutdown so
|
|
// children don't leak past the parent. Best-effort: stop errors are
|
|
// logged but not aggregated.
|
|
func (r *Registry) Stop(ctx context.Context) {
|
|
r.sessions.Range(func(k, v any) bool {
|
|
e := v.(*entry)
|
|
if err := e.backend.Load().Stop(ctx); err != nil {
|
|
r.log.Warn("session stop", slog.String("sid", e.sid), slog.String("err", err.Error()))
|
|
}
|
|
r.sessions.Delete(k)
|
|
r.count.Add(-1)
|
|
return true
|
|
})
|
|
}
|
|
|
|
// LastActive returns the last-active wall-clock time for sid, or the
|
|
// zero time if no such session exists. Phase 5c's reaper uses this to
|
|
// evict idle sessions.
|
|
func (r *Registry) LastActive(sid string) time.Time {
|
|
v, ok := r.sessions.Load(sid)
|
|
if !ok {
|
|
return time.Time{}
|
|
}
|
|
return time.Unix(0, v.(*entry).lastActive.Load())
|
|
}
|
|
|
|
// errMaxSessions signals the cap was hit. Internal sentinel only —
// callers see a 503 carrying a Retry-After header (see respondSpawnError).
var errMaxSessions = errors.New("session: max sessions reached")
|
|
|
|
// newSessionID mints a session id from 24 bytes of crypto/rand output,
// encoded as 48 lowercase hex characters (192 bits of entropy). That is
// ample for global uniqueness within one broker. A crypto/rand failure
// leaves no safe way to mint ids, so it panics.
func newSessionID() string {
	var raw [24]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic("session: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(raw[:])
}
|