235 lines
7.1 KiB
Go
235 lines
7.1 KiB
Go
|
|
// Package session is the broker's MCP session glue. It maps the
|
||
|
|
// `Mcp-Session-Id` header onto a running forgejo-mcp subprocess (managed
|
||
|
|
// by internal/supervisor) plus a bridge (internal/bridge) that pipes
|
||
|
|
// JSON-RPC traffic.
|
||
|
|
//
|
||
|
|
// One Registry handles all sessions. New session ids are minted on the
|
||
|
|
// first /mcp POST that arrives without a session header; subsequent
|
||
|
|
// requests with that header are dispatched to the same backend so a
|
||
|
|
// single user keeps the same forgejo-mcp child for the life of the
|
||
|
|
// session.
|
||
|
|
//
|
||
|
|
// The Registry knows when to spawn — it does not know how. Phase-5a tests
|
||
|
|
// inject fake SpawnFuncs to exercise the lifecycle without forking real
|
||
|
|
// processes. A production wiring lives in cmd/broker (phase 5c will
|
||
|
|
// finalise that).
|
||
|
|
package session
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"encoding/hex"
|
||
|
|
"errors"
|
||
|
|
"log/slog"
|
||
|
|
"net/http"
|
||
|
|
"sync"
|
||
|
|
"sync/atomic"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Backend is what the Registry holds for one live session: the handler
// that serves its /mcp traffic, a way to shut it down, and a signal that
// it has exited. Per the package doc it fronts one forgejo-mcp
// subprocess and its bridge.
type Backend struct {
	// Handler receives every /mcp request routed to this session.
	Handler http.Handler

	// Stop shuts the backend down. The Registry calls it on broker
	// shutdown and when rolling back a failed registration.
	Stop func(ctx context.Context) error

	// Done is received from by the Registry; once it yields (normally
	// by being closed) the session is dropped from the Registry.
	Done <-chan struct{}
}
|
||
|
|
|
||
|
|
// SpawnFunc constructs a Backend for the supplied OAuth session. The
|
||
|
|
// production implementation spawns forgejo-mcp via supervisor and wires
|
||
|
|
// a bridge; tests pass fakes.
|
||
|
|
type SpawnFunc func(ctx context.Context, sess *oauth.Session) (*Backend, error)
|
||
|
|
|
||
|
|
// Config bundles the inputs to New.
|
||
|
|
type Config struct {
|
||
|
|
Spawn SpawnFunc
|
||
|
|
MaxSessions int // 0 means unlimited
|
||
|
|
Log *slog.Logger // optional; defaults to discard
|
||
|
|
Now func() time.Time // optional; defaults to time.Now
|
||
|
|
}
|
||
|
|
|
||
|
|
// Registry tracks active sessions, dispatches requests to them, and tears
|
||
|
|
// them down on broker shutdown. Construct via New — the embedded sync.Map
|
||
|
|
// makes value copies unsafe.
|
||
|
|
type Registry struct {
|
||
|
|
spawn SpawnFunc
|
||
|
|
maxSessions int
|
||
|
|
log *slog.Logger
|
||
|
|
now func() time.Time
|
||
|
|
|
||
|
|
sessions sync.Map // string sid → *entry
|
||
|
|
count atomic.Int32
|
||
|
|
}
|
||
|
|
|
||
|
|
type entry struct {
|
||
|
|
sid string
|
||
|
|
backend *Backend
|
||
|
|
lastActive atomic.Int64 // unix nanoseconds; bumped per request
|
||
|
|
oauthSess *oauth.Session
|
||
|
|
}
|
||
|
|
|
||
|
|
// SessionIDHeader names the HTTP header that carries the MCP session id
// in both directions: it is read from incoming requests and set on the
// response that mints a new session.
const SessionIDHeader = "Mcp-Session-Id"
|
||
|
|
|
||
|
|
// New validates the config and returns a Registry ready to serve.
|
||
|
|
func New(cfg Config) (*Registry, error) {
|
||
|
|
if cfg.Spawn == nil {
|
||
|
|
return nil, errors.New("session: Spawn is required")
|
||
|
|
}
|
||
|
|
log := cfg.Log
|
||
|
|
if log == nil {
|
||
|
|
log = slog.New(slog.DiscardHandler)
|
||
|
|
}
|
||
|
|
now := cfg.Now
|
||
|
|
if now == nil {
|
||
|
|
now = time.Now
|
||
|
|
}
|
||
|
|
return &Registry{
|
||
|
|
spawn: cfg.Spawn,
|
||
|
|
maxSessions: cfg.MaxSessions,
|
||
|
|
log: log,
|
||
|
|
now: now,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Handler returns an http.Handler suitable for mounting at /mcp. Wrap it
|
||
|
|
// with oauth.Authenticator.RequireBearer so every request carries a
|
||
|
|
// validated Session in its context.
|
||
|
|
func (r *Registry) Handler() http.Handler {
|
||
|
|
return http.HandlerFunc(r.serve)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) serve(w http.ResponseWriter, req *http.Request) {
|
||
|
|
oauthSess, ok := oauth.SessionFromContext(req.Context())
|
||
|
|
if !ok {
|
||
|
|
// Mounted without RequireBearer; treat as a programmer error.
|
||
|
|
http.Error(w, "no auth session in context", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
sid := req.Header.Get(SessionIDHeader)
|
||
|
|
if sid == "" {
|
||
|
|
// No session id yet — this must be an `initialize`. Mint a session.
|
||
|
|
e, err := r.spawnSession(req.Context(), oauthSess)
|
||
|
|
if err != nil {
|
||
|
|
r.respondSpawnError(w, err)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
w.Header().Set(SessionIDHeader, e.sid)
|
||
|
|
e.lastActive.Store(r.now().UnixNano())
|
||
|
|
e.backend.Handler.ServeHTTP(w, req)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Lookup; reject unknown session ids with 410 Gone so the client
|
||
|
|
// re-initialises rather than retrying forever (per MCP guidance).
|
||
|
|
v, ok := r.sessions.Load(sid)
|
||
|
|
if !ok {
|
||
|
|
http.Error(w, "unknown or expired session", http.StatusGone)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
e := v.(*entry)
|
||
|
|
if e.oauthSess.BrokerTokenHash != oauthSess.BrokerTokenHash {
|
||
|
|
// Session id is bound to the OAuth token that minted it. A
|
||
|
|
// different bearer probing a stolen sid gets 403 — not 401, so
|
||
|
|
// this is distinct from "your token is bad" and from "we don't
|
||
|
|
// know that sid" (410 above). Defence in depth.
|
||
|
|
http.Error(w, "session bound to a different token", http.StatusForbidden)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
e.lastActive.Store(r.now().UnixNano())
|
||
|
|
e.backend.Handler.ServeHTTP(w, req)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) spawnSession(ctx context.Context, oauthSess *oauth.Session) (*entry, error) {
|
||
|
|
if r.maxSessions > 0 && int(r.count.Load()) >= r.maxSessions {
|
||
|
|
return nil, errMaxSessions
|
||
|
|
}
|
||
|
|
|
||
|
|
backend, err := r.spawn(ctx, oauthSess)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
sid := newSessionID()
|
||
|
|
e := &entry{sid: sid, backend: backend, oauthSess: oauthSess}
|
||
|
|
e.lastActive.Store(r.now().UnixNano())
|
||
|
|
|
||
|
|
if _, loaded := r.sessions.LoadOrStore(sid, e); loaded {
|
||
|
|
// Astronomically unlikely (24-byte random collision); roll back.
|
||
|
|
_ = backend.Stop(ctx)
|
||
|
|
return nil, errors.New("session: id collision")
|
||
|
|
}
|
||
|
|
r.count.Add(1)
|
||
|
|
|
||
|
|
// When the child exits on its own (crash, OOM, etc.), reap the entry.
|
||
|
|
go func() {
|
||
|
|
<-backend.Done
|
||
|
|
r.removeSession(sid)
|
||
|
|
}()
|
||
|
|
return e, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) removeSession(sid string) {
|
||
|
|
if _, ok := r.sessions.LoadAndDelete(sid); ok {
|
||
|
|
r.count.Add(-1)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *Registry) respondSpawnError(w http.ResponseWriter, err error) {
|
||
|
|
if errors.Is(err, errMaxSessions) {
|
||
|
|
w.Header().Set("Retry-After", "30")
|
||
|
|
http.Error(w, "broker at max sessions", http.StatusServiceUnavailable)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
r.log.Error("session spawn failed", slog.String("err", err.Error()))
|
||
|
|
http.Error(w, "session spawn failed", http.StatusInternalServerError)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Active returns the number of currently-tracked sessions. Mostly for
|
||
|
|
// tests and metrics.
|
||
|
|
func (r *Registry) Active() int { return int(r.count.Load()) }
|
||
|
|
|
||
|
|
// Stop tears down every active session. Used at broker shutdown so
|
||
|
|
// children don't leak past the parent. Best-effort: stop errors are
|
||
|
|
// logged but not aggregated.
|
||
|
|
func (r *Registry) Stop(ctx context.Context) {
|
||
|
|
r.sessions.Range(func(k, v any) bool {
|
||
|
|
e := v.(*entry)
|
||
|
|
if err := e.backend.Stop(ctx); err != nil {
|
||
|
|
r.log.Warn("session stop", slog.String("sid", e.sid), slog.String("err", err.Error()))
|
||
|
|
}
|
||
|
|
r.sessions.Delete(k)
|
||
|
|
r.count.Add(-1)
|
||
|
|
return true
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// LastActive returns the last-active wall-clock time for sid, or the
|
||
|
|
// zero time if no such session exists. Phase 5c's reaper uses this to
|
||
|
|
// evict idle sessions.
|
||
|
|
func (r *Registry) LastActive(sid string) time.Time {
|
||
|
|
v, ok := r.sessions.Load(sid)
|
||
|
|
if !ok {
|
||
|
|
return time.Time{}
|
||
|
|
}
|
||
|
|
return time.Unix(0, v.(*entry).lastActive.Load())
|
||
|
|
}
|
||
|
|
|
||
|
|
// errMaxSessions is the package-private sentinel returned when the
// session cap has been reached. It never leaves the package: HTTP
// clients are shown a 503 instead.
var errMaxSessions = errors.New("session: max sessions reached")
|
||
|
|
|
||
|
|
// newSessionID produces a fresh session id: 24 bytes from crypto/rand,
// hex-encoded to 48 lowercase characters (192 bits of entropy). It
// panics if the system random source reports an error, since no safe
// id can be generated without it.
func newSessionID() string {
	var raw [24]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic("session: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(raw[:])
}
|