package oauth import ( "context" "database/sql" "errors" "fmt" "net/http" "strings" "time" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/store" ) // Session is the per-request OAuth context attached by the bearer // middleware. Downstream handlers (the MCP endpoint, in phase 5a) read // the upstream Forgejo token from here to spawn forgejo-mcp subprocesses // scoped to the right user. type Session struct { ClientID string ForgejoUserID int64 ForgejoUsername string Scopes string BrokerTokenHash string // SHA-256 hex of the broker token; for log correlation ForgejoToken string // plaintext upstream token — keep in memory, never log ForgejoRefresh string ForgejoTokenExp time.Time } // Authenticator resolves Bearer tokens against the access_tokens table. // Use Authenticator.RequireBearer to wrap the protected handler. type Authenticator struct { Store *store.Store Now func() time.Time // optional; defaults to time.Now } type sessionCtxKey struct{} // SessionFromContext returns the Session attached by RequireBearer, if any. func SessionFromContext(ctx context.Context) (*Session, bool) { s, ok := ctx.Value(sessionCtxKey{}).(*Session) return s, ok } // ContextWithSession attaches a Session to ctx using the same key // RequireBearer uses. Primarily useful in tests that want to drive a // gated handler without standing up the full OAuth flow. func ContextWithSession(ctx context.Context, s *Session) context.Context { return context.WithValue(ctx, sessionCtxKey{}, s) } // RequireBearer is HTTP middleware that: // 1. Demands an `Authorization: Bearer ` header. // 2. Looks the token up by SHA-256 hash in access_tokens. // 3. Rejects expired or revoked tokens. // 4. Attaches the resolved Session to the request context for downstream // handlers to read via SessionFromContext. // // Failures emit a 401 with an RFC 6750 §3 WWW-Authenticate header carrying // the appropriate error code (invalid_token / invalid_request). 
func (a *Authenticator) RequireBearer(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { now := a.Now if now == nil { now = time.Now } raw := r.Header.Get("Authorization") if raw == "" { respondAuthError(w, "invalid_request", "missing Authorization header") return } token, ok := strings.CutPrefix(raw, "Bearer ") if !ok || token == "" { respondAuthError(w, "invalid_request", "Authorization header must use Bearer scheme") return } sess, err := a.lookupSession(r.Context(), hashToken(token), now()) if err != nil { switch { case errors.Is(err, errTokenNotFound): respondAuthError(w, "invalid_token", "unknown token") case errors.Is(err, errTokenExpired): respondAuthError(w, "invalid_token", "token expired") case errors.Is(err, errTokenRevoked): respondAuthError(w, "invalid_token", "token revoked") default: // Unexpected DB or scan error: don't leak internals to // the caller. Logging would land in middleware-of-the- // future once we wire a logger here. respondAuthError(w, "invalid_token", "auth lookup failed") } return } ctx := context.WithValue(r.Context(), sessionCtxKey{}, sess) next.ServeHTTP(w, r.WithContext(ctx)) }) } // Sentinel errors so RequireBearer can render distinct WWW-Authenticate // reasons for the operator while always returning 401 to the client. 
var ( errTokenNotFound = errors.New("token not found") errTokenExpired = errors.New("token expired") errTokenRevoked = errors.New("token revoked") ) func (a *Authenticator) lookupSession(ctx context.Context, tokenHash string, now time.Time) (*Session, error) { var ( clientID, fjUsername, scopes, fjAccess, fjRefresh string fjUserID int64 expiresAt, fjExpiresAt int64 revokedAt sql.NullInt64 ) row := a.Store.DB().QueryRowContext(ctx, `SELECT client_id, forgejo_user_id, forgejo_username, scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, expires_at, revoked_at FROM access_tokens WHERE token_hash = ?`, tokenHash) err := row.Scan(&clientID, &fjUserID, &fjUsername, &scopes, &fjAccess, &fjRefresh, &fjExpiresAt, &expiresAt, &revokedAt) if errors.Is(err, sql.ErrNoRows) { return nil, errTokenNotFound } if err != nil { return nil, err } if revokedAt.Valid { return nil, errTokenRevoked } if now.Unix() > expiresAt { return nil, errTokenExpired } return &Session{ ClientID: clientID, ForgejoUserID: fjUserID, ForgejoUsername: fjUsername, Scopes: scopes, BrokerTokenHash: tokenHash, ForgejoToken: fjAccess, ForgejoRefresh: fjRefresh, ForgejoTokenExp: time.Unix(fjExpiresAt, 0).UTC(), }, nil } // respondAuthError writes a 401 with a WWW-Authenticate header per RFC 6750 // §3. The body stays empty — error info goes in the header so it's discoverable // to compliant clients without leaking detail in a body that browsers might // render. func respondAuthError(w http.ResponseWriter, errorCode, description string) { w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error="%s", error_description="%s"`, escapeHeader(errorCode), escapeHeader(description))) w.WriteHeader(http.StatusUnauthorized) } // escapeHeader strips characters that would break a quoted-string in an // HTTP header value. Conservative: only allow safe ASCII. 
// Every error code passed in today is a fixed constant; the filter is
// defense in depth in case a later change interpolates caller-controlled
// text. Despite the name, disallowed characters (control characters,
// anything at or above 0x7f, '"' and '\\') are dropped, not escaped.
func escapeHeader(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r < 0x20, r >= 0x7f, r == '"', r == '\\':
			return -1 // negative return tells strings.Map to drop the rune
		default:
			return r
		}
	}, s)
}