forgejo-mcp-broker/internal/oauth/auth.go

172 lines
5.7 KiB
Go
Raw Permalink Normal View History

package oauth
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
)
// Session carries the OAuth state resolved for a single request. The bearer
// middleware (RequireBearer) builds it from one access_tokens row and stores
// it in the request context; downstream handlers (the MCP endpoint, in
// phase 5a) read the upstream Forgejo token from it to spawn forgejo-mcp
// subprocesses scoped to the right user.
type Session struct {
	// ClientID is the client_id column of the matched token row.
	ClientID string
	// ForgejoUserID is the forgejo_user_id column of the matched token row.
	ForgejoUserID int64
	// ForgejoUsername is the forgejo_username column of the matched token row.
	ForgejoUsername string
	// Scopes is the scopes column, carried as the single stored string.
	Scopes string
	// BrokerTokenHash is the SHA-256 hex of the broker token; it exists for
	// log correlation.
	BrokerTokenHash string
	// ForgejoToken is the plaintext upstream token. Keep it in memory and
	// never log it.
	ForgejoToken string
	// ForgejoRefresh is the forgejo_refresh_token column of the matched row.
	ForgejoRefresh string
	// ForgejoTokenExp is forgejo_token_expires_at, read as Unix seconds and
	// expressed in UTC.
	ForgejoTokenExp time.Time
}
// Authenticator validates Bearer tokens by looking them up in the
// access_tokens table. Wrap a protected handler with
// Authenticator.RequireBearer to enforce the check.
type Authenticator struct {
	// Store supplies the database handle (via Store.DB) that token lookups
	// query.
	Store *store.Store
	// Now is an optional clock used for expiry checks; when nil,
	// RequireBearer falls back to time.Now.
	Now func() time.Time
}
// sessionCtxKey is the empty, unexported type whose zero value keys the
// *Session stored in a request context. Because the type is unexported,
// code outside this package cannot construct the key and has to go through
// SessionFromContext / ContextWithSession.
type sessionCtxKey struct{}
// SessionFromContext fetches the Session that RequireBearer (or
// ContextWithSession) stored in ctx. The boolean is false when ctx holds no
// value under the session key or the value is not a *Session.
func SessionFromContext(ctx context.Context) (*Session, bool) {
	attached := ctx.Value(sessionCtxKey{})
	sess, found := attached.(*Session)
	return sess, found
}
feat(session): MCP session registry + spawn-on-initialize (forgejo-mcp-broker-t81) Adds internal/session.Registry, the MCP session glue that maps Mcp-Session-Id to a running forgejo-mcp child + bridge. Lifecycle: - First /mcp POST without Mcp-Session-Id: SpawnFunc creates a backend (in production: supervisor.Start + bridge.New); registry mints a 192-bit hex session id, attaches it to the response header, and dispatches the request to the new backend. - Subsequent POSTs with the header dispatch to the existing backend. - Unknown sids → 410 Gone (per MCP guidance, so clients re-initialise instead of retrying forever). - Sids are bound to the OAuth token that minted them: a different bearer probing a stolen sid gets 403, distinct from "your token is bad" (401) and "sid unknown" (410). Cleanup: - When backend.Done closes (child exited on its own — crash, OOM, user-driven shutdown), a goroutine reaps the entry. - Stop tears every session down on broker shutdown. The 30s idle reaper and Forgejo token rotation come in 5c. The Registry is decoupled from supervisor and bridge via SpawnFunc, so tests don't need to fork real processes — they hand the registry a fake that returns a controllable Backend. Also added oauth.ContextWithSession so the session tests can inject an oauth.Session into request contexts without standing up the full bearer-middleware chain. Tests: 83.3% coverage. Cover spawn-on-initialize, sid reuse, unknown sid, max-session cap with Retry-After, no-auth-context guard, sid hijack defense (token mismatch → 403), Done-channel reaping, and graceful Stop. Closes forgejo-mcp-broker-t81. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 17:24:25 +02:00
// ContextWithSession returns a child of ctx that carries s under the same
// key RequireBearer uses, so SessionFromContext can retrieve it. Its main
// use is in tests that need to drive a gated handler without standing up
// the full OAuth flow.
func ContextWithSession(ctx context.Context, s *Session) context.Context {
	var key sessionCtxKey // zero value; identical to sessionCtxKey{}
	return context.WithValue(ctx, key, s)
}
// RequireBearer is HTTP middleware that:
//  1. Demands an `Authorization: Bearer <token>` header. The scheme name is
//     matched case-insensitively and one or more spaces may separate it
//     from the token.
//  2. Looks the token up by SHA-256 hash in access_tokens.
//  3. Rejects expired or revoked tokens.
//  4. Attaches the resolved Session to the request context for downstream
//     handlers to read via SessionFromContext.
//
// Failures emit a 401 with an RFC 6750 §3 WWW-Authenticate header carrying
// the appropriate error code (invalid_token / invalid_request); next is not
// invoked in that case.
func (a *Authenticator) RequireBearer(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The clock is resolved on every request, so a Now assigned after
		// the handler was wrapped is still honoured.
		now := a.Now
		if now == nil {
			now = time.Now
		}

		raw := r.Header.Get("Authorization")
		if raw == "" {
			respondAuthError(w, "invalid_request", "missing Authorization header")
			return
		}

		// HTTP authentication scheme names are case-insensitive, and the
		// Bearer credentials grammar allows 1*SP between scheme and token,
		// so compare with EqualFold and drop any extra separator spaces
		// instead of requiring the exact prefix "Bearer ".
		scheme, token, found := strings.Cut(raw, " ")
		token = strings.TrimLeft(token, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
			respondAuthError(w, "invalid_request", "Authorization header must use Bearer scheme")
			return
		}

		sess, err := a.lookupSession(r.Context(), hashToken(token), now())
		if err != nil {
			switch {
			case errors.Is(err, errTokenNotFound):
				respondAuthError(w, "invalid_token", "unknown token")
			case errors.Is(err, errTokenExpired):
				respondAuthError(w, "invalid_token", "token expired")
			case errors.Is(err, errTokenRevoked):
				respondAuthError(w, "invalid_token", "token revoked")
			default:
				// Unexpected DB or scan error: don't leak internals to
				// the caller. Logging would land in middleware-of-the-
				// future once we wire a logger here.
				respondAuthError(w, "invalid_token", "auth lookup failed")
			}
			return
		}

		// Store the session under the package's context key so downstream
		// handlers can fetch it with SessionFromContext.
		next.ServeHTTP(w, r.WithContext(ContextWithSession(r.Context(), sess)))
	})
}
// Sentinel errors returned by lookupSession. RequireBearer matches on them
// to pick a distinct WWW-Authenticate description for the operator, while
// the client always receives a 401.
var (
	// errTokenNotFound: no access_tokens row matches the token hash.
	errTokenNotFound = errors.New("token not found")
	// errTokenExpired: the row's expires_at lies before the current time.
	errTokenExpired = errors.New("token expired")
	// errTokenRevoked: the row's revoked_at is set.
	errTokenRevoked = errors.New("token revoked")
)
// lookupSession loads the access_tokens row whose token_hash equals
// tokenHash and turns it into a Session.
//
// It returns errTokenNotFound when no row matches, errTokenRevoked when
// revoked_at is non-NULL, and errTokenExpired when now (in Unix seconds) is
// later than expires_at; a token whose expires_at equals now is still
// accepted. Revocation is checked before expiry. Any other query or scan
// error is returned unchanged.
func (a *Authenticator) lookupSession(ctx context.Context, tokenHash string, now time.Time) (*Session, error) {
	const query = `SELECT client_id, forgejo_user_id, forgejo_username, scopes,
forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
expires_at, revoked_at
FROM access_tokens WHERE token_hash = ?`

	// Scan straight into the Session; the three columns that need
	// post-processing land in temporaries first.
	sess := &Session{BrokerTokenHash: tokenHash}
	var (
		brokerExpiry  int64
		forgejoExpiry int64
		revoked       sql.NullInt64
	)
	err := a.Store.DB().QueryRowContext(ctx, query, tokenHash).Scan(
		&sess.ClientID, &sess.ForgejoUserID, &sess.ForgejoUsername, &sess.Scopes,
		&sess.ForgejoToken, &sess.ForgejoRefresh, &forgejoExpiry,
		&brokerExpiry, &revoked,
	)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, errTokenNotFound
	case err != nil:
		return nil, err
	case revoked.Valid:
		// Any non-NULL revoked_at marks the token as revoked.
		return nil, errTokenRevoked
	case now.Unix() > brokerExpiry:
		return nil, errTokenExpired
	}

	sess.ForgejoTokenExp = time.Unix(forgejoExpiry, 0).UTC()
	return sess, nil
}
// respondAuthError replies 401 Unauthorized with an RFC 6750 §3
// WWW-Authenticate challenge built from errorCode and description, both
// passed through escapeHeader first. No body is written: the error detail
// lives in the header so compliant clients can discover it without leaking
// anything into a body that browsers might render.
func respondAuthError(w http.ResponseWriter, errorCode, description string) {
	code := escapeHeader(errorCode)
	desc := escapeHeader(description)
	challenge := fmt.Sprintf(`Bearer error="%s", error_description="%s"`, code, desc)

	w.Header().Set("WWW-Authenticate", challenge)
	w.WriteHeader(http.StatusUnauthorized)
}
// escapeHeader returns s with every character removed that could break a
// quoted-string in an HTTP header value. It is deliberately conservative:
// only printable ASCII (0x20 through 0x7e) other than the double quote and
// the backslash survives. The error codes emitted here are well-known
// constants, so this is defense in depth against a future bug that
// interpolates user input.
func escapeHeader(s string) string {
	keep := func(r rune) rune {
		switch {
		case r < 0x20, r >= 0x7f: // control characters, DEL, and all non-ASCII
			return -1
		case r == '"', r == '\\': // would terminate or escape the quoted-string
			return -1
		}
		return r
	}
	// strings.Map drops a character whenever the mapping returns a negative
	// rune.
	return strings.Map(keep, s)
}