Adds RFC 8414 (oauth-authorization-server) and RFC 9728 (oauth-protected-resource) metadata documents. Both URLs are derived from cfg.Issuer at construction time, never from inbound request headers. Test TestDiscovery_IssuerIgnoresHostHeader explicitly probes this — a malicious Host: evil.example.com value must not leak into the published metadata. Defense against the OAuth metadata-spoofing class starts at the discovery layer. Capabilities published reflect the actual OAuth surface: - response_types_supported = ["code"] - grant_types_supported = ["authorization_code", "refresh_token"] - code_challenge_methods_supported = ["S256"] (PKCE only, no plain) - token_endpoint_auth_methods_supported = ["none"] (PKCE-only public clients) Protected-resource metadata advertises /mcp as the resource; phase 5 will mount the gated MCP endpoint there. Closes forgejo-mcp-broker-b2o. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
877 lines
31 KiB
Go
877 lines
31 KiB
Go
// Package oauth implements the broker's OAuth 2.1 authorization server
|
|
// surface — what Claude.ai (and other MCP clients) talk to. User auth is
|
|
// delegated to upstream Forgejo via internal/forgejo.
|
|
//
|
|
// Endpoints (RFC numbers in parentheses):
|
|
// POST /oauth/register — RFC 7591 dynamic client registration
|
|
// GET /oauth/authorize — RFC 6749 / 7636 authorize with PKCE S256
|
|
// GET /oauth/callback — Forgejo redirects back here after user consent
|
|
// POST /oauth/token — authorization_code and refresh_token grants
|
|
// POST /oauth/revoke — RFC 7009 token revocation
|
|
//
|
|
// PKCE: required, S256 only. Plain method is rejected (this is OAuth 2.1).
|
|
//
|
|
// Token storage: every broker-issued access/refresh token is stored as a
|
|
// hex-encoded SHA-256 hash. The plaintext leaves the broker exactly once —
|
|
// in the body of the /oauth/token response.
|
|
//
|
|
// Pending authorizations (between /authorize and /callback) live in memory
|
|
// with a short TTL. A broker restart drops them, forcing the user to
|
|
// re-authorize; that's an acceptable UX hit for not adding another table.
|
|
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
|
|
)
|
|
|
|
// Lifetimes of the short-lived artefacts this package issues or tracks.
const (
	// AuthCodeTTL is how long a broker-issued authorization code may be
	// redeemed. RFC 6749 §10.5 asks for a "very short" lifetime; ten
	// minutes is in line with most authorization servers.
	AuthCodeTTL = 10 * time.Minute

	// PendingAuthTTL is the upper bound on an in-flight
	// /authorize → /callback round trip. A flow that takes longer than
	// this is almost certainly abandoned.
	PendingAuthTTL = 10 * time.Minute

	// AccessTokenTTL is the lifetime of a broker access token. Refresh
	// tokens live considerably longer; see RefreshTokenTTL.
	AccessTokenTTL = time.Hour

	// RefreshTokenTTL is the lifetime of a broker refresh token. Thirty
	// days keeps a daily-active user signed in indefinitely while still
	// putting a time limit on theft via a stale backup.
	RefreshTokenTTL = 30 * 24 * time.Hour
)
|
|
|
|
// Server is the OAuth authorization server. Construct one with NewServer
|
|
// and mount its Handler under the broker's HTTP mux.
|
|
type Server struct {
|
|
store *store.Store
|
|
forgejo *forgejo.Client
|
|
issuer string // public URL, e.g. https://mcp.example.com — never derived from headers
|
|
scopes string // space-separated scope set requested from upstream Forgejo
|
|
now func() time.Time
|
|
log *slog.Logger
|
|
mu sync.Mutex // guards pendingAuths cleanup; map ops use sync.Map natively
|
|
pending sync.Map // forgejoState string → *pendingAuth
|
|
stopCh chan struct{}
|
|
}
|
|
|
|
// Config bundles Server dependencies.
|
|
type Config struct {
|
|
Store *store.Store
|
|
Forgejo *forgejo.Client
|
|
Issuer string // required; e.g. https://mcp.example.com
|
|
Scopes string // optional; space-separated; defaults to ""
|
|
Now func() time.Time
|
|
Log *slog.Logger
|
|
}
|
|
|
|
// NewServer validates the config and starts the periodic pending-auth
|
|
// reaper. Stop the reaper with Server.Close.
|
|
func NewServer(cfg Config) (*Server, error) {
|
|
if cfg.Store == nil {
|
|
return nil, errors.New("oauth: Store is required")
|
|
}
|
|
if cfg.Forgejo == nil {
|
|
return nil, errors.New("oauth: Forgejo client is required")
|
|
}
|
|
if cfg.Issuer == "" {
|
|
return nil, errors.New("oauth: Issuer is required")
|
|
}
|
|
now := cfg.Now
|
|
if now == nil {
|
|
now = time.Now
|
|
}
|
|
logger := cfg.Log
|
|
if logger == nil {
|
|
logger = slog.New(slog.DiscardHandler)
|
|
}
|
|
s := &Server{
|
|
store: cfg.Store,
|
|
forgejo: cfg.Forgejo,
|
|
issuer: strings.TrimRight(cfg.Issuer, "/"),
|
|
scopes: cfg.Scopes,
|
|
now: now,
|
|
log: logger,
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
go s.reapPendingLoop()
|
|
return s, nil
|
|
}
|
|
|
|
// Close stops the pending-auth reaper. Safe to call multiple times.
|
|
func (s *Server) Close() {
|
|
select {
|
|
case <-s.stopCh:
|
|
// already closed
|
|
default:
|
|
close(s.stopCh)
|
|
}
|
|
}
|
|
|
|
// Handler returns the http.Handler exposing OAuth endpoints plus the two
|
|
// discovery documents. The broker's outer mux should mount this at the
|
|
// root (not under /oauth) so the .well-known paths land at their spec-
|
|
// mandated location.
|
|
func (s *Server) Handler() http.Handler {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("POST /oauth/register", s.handleRegister)
|
|
mux.HandleFunc("GET /oauth/authorize", s.handleAuthorize)
|
|
mux.HandleFunc("GET /oauth/callback", s.handleCallback)
|
|
mux.HandleFunc("POST /oauth/token", s.handleToken)
|
|
mux.HandleFunc("POST /oauth/revoke", s.handleRevoke)
|
|
mux.HandleFunc("GET /.well-known/oauth-authorization-server", s.handleASMetadata)
|
|
mux.HandleFunc("GET /.well-known/oauth-protected-resource", s.handlePRMetadata)
|
|
return mux
|
|
}
|
|
|
|
// pendingAuth is what the broker remembers, in memory only, between a
// client's /authorize request and the matching upstream /callback.
type pendingAuth struct {
	// The requesting client and where its answer must be delivered.
	clientID, redirectURI string
	// The client's PKCE commitment, checked later at the token endpoint.
	codeChallenge, codeChallengeMethod string
	// Scope string exactly as the client sent it.
	scopes string
	// The client's own state value, echoed back on the final redirect.
	clientState string
	// Instant after which this entry is no longer honoured.
	expiresAt time.Time
}
|
|
|
|
// ============================================================================
|
|
// Helpers — token generation, hashing, OAuth error responses
|
|
// ============================================================================
|
|
|
|
// secureToken produces nBytes of cryptographically secure randomness and
// returns it hex-encoded, so the result is 2*nBytes characters long
// (32 bytes ⇒ 256 bits of entropy ⇒ 64 hex characters).
//
// It panics if crypto/rand cannot deliver: running without entropy is
// worse than crashing.
func secureToken(nBytes int) string {
	raw := make([]byte, nBytes)
	_, err := rand.Read(raw)
	if err != nil {
		panic("oauth: crypto/rand failed: " + err.Error())
	}

	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded)
}
|
|
|
|
// hashToken maps a plaintext token to the lowercase hex form of its
// SHA-256 digest. The store only ever sees this value, so plaintext
// tokens are never persisted.
func hashToken(token string) string {
	h := sha256.New()
	h.Write([]byte(token)) // hash.Hash writes never return an error
	return hex.EncodeToString(h.Sum(nil))
}
|
|
|
|
// verifyPKCE reports whether challenge is the S256 transform of verifier,
// i.e. the unpadded base64url encoding of SHA-256(verifier). The final
// comparison is constant-time so the check does not leak how much of the
// challenge matched.
func verifyPKCE(verifier, challenge string) bool {
	digest := sha256.Sum256([]byte(verifier))

	expected := make([]byte, base64.RawURLEncoding.EncodedLen(len(digest)))
	base64.RawURLEncoding.Encode(expected, digest[:])

	match := subtle.ConstantTimeCompare(expected, []byte(challenge))
	return match == 1
}
|
|
|
|
// writeJSON sends obj as a JSON body with the given status code, marking
// the response application/json and uncacheable (Cache-Control: no-store).
//
// An encoding failure is deliberately discarded: by the time it could
// happen the status line and headers have already been written, so there
// is nothing useful left to tell the client.
func writeJSON(w http.ResponseWriter, status int, obj any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Cache-Control", "no-store")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	_ = enc.Encode(obj)
}
|
|
|
|
// writeOAuthError renders an RFC 6749 §5.2 error response.
|
|
func writeOAuthError(w http.ResponseWriter, status int, code, description string) {
|
|
writeJSON(w, status, map[string]string{
|
|
"error": code,
|
|
"error_description": description,
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Pending-auth reaper
|
|
// ============================================================================
|
|
|
|
func (s *Server) reapPendingLoop() {
|
|
t := time.NewTicker(time.Minute)
|
|
defer t.Stop()
|
|
for {
|
|
select {
|
|
case <-s.stopCh:
|
|
return
|
|
case <-t.C:
|
|
s.reapPending()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) reapPending() {
|
|
now := s.now()
|
|
s.pending.Range(func(k, v any) bool {
|
|
if pa, ok := v.(*pendingAuth); ok && now.After(pa.expiresAt) {
|
|
s.pending.Delete(k)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// /.well-known/* — discovery metadata (RFC 8414, RFC 9728)
|
|
// ============================================================================
|
|
//
|
|
// All URLs in these documents are built from the configured issuer, never
|
|
// from inbound request headers. Publishing an attacker-controlled issuer
|
|
// is a classic OAuth metadata-spoofing vector — defending against it has
|
|
// to start at the discovery layer.
|
|
|
|
// asMetadata is the JSON shape of the authorization-server metadata
// document (RFC 8414) served at /.well-known/oauth-authorization-server.
// Field order is the order the keys appear in the encoded document.
type asMetadata struct {
	// Issuer identifier and the endpoint URLs derived from it.
	Issuer                string `json:"issuer"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	RegistrationEndpoint  string `json:"registration_endpoint"`
	RevocationEndpoint    string `json:"revocation_endpoint,omitempty"`

	// Capability lists advertised to clients.
	ResponseTypesSupported            []string `json:"response_types_supported"`
	GrantTypesSupported               []string `json:"grant_types_supported"`
	CodeChallengeMethodsSupported     []string `json:"code_challenge_methods_supported"`
	TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`

	// Omitted from the document when empty.
	ScopesSupported []string `json:"scopes_supported,omitempty"`
}
|
|
|
|
func (s *Server) handleASMetadata(w http.ResponseWriter, r *http.Request) {
|
|
md := asMetadata{
|
|
Issuer: s.issuer,
|
|
AuthorizationEndpoint: s.issuer + "/oauth/authorize",
|
|
TokenEndpoint: s.issuer + "/oauth/token",
|
|
RegistrationEndpoint: s.issuer + "/oauth/register",
|
|
RevocationEndpoint: s.issuer + "/oauth/revoke",
|
|
ResponseTypesSupported: []string{"code"},
|
|
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
|
CodeChallengeMethodsSupported: []string{"S256"},
|
|
TokenEndpointAuthMethodsSupported: []string{"none"},
|
|
}
|
|
if s.scopes != "" {
|
|
md.ScopesSupported = strings.Fields(s.scopes)
|
|
}
|
|
writeJSON(w, http.StatusOK, md)
|
|
}
|
|
|
|
// prMetadata is the JSON shape of the protected-resource metadata
// document (RFC 9728) served at /.well-known/oauth-protected-resource.
type prMetadata struct {
	// Resource is the URL of the protected resource being described.
	Resource string `json:"resource"`
	// AuthorizationServers lists the issuers able to mint tokens for it.
	AuthorizationServers []string `json:"authorization_servers"`
	// BearerMethodsSupported lists how a bearer token may be presented.
	BearerMethodsSupported []string `json:"bearer_methods_supported"`
	// ScopesSupported is omitted from the document when empty.
	ScopesSupported []string `json:"scopes_supported,omitempty"`
}
|
|
|
|
func (s *Server) handlePRMetadata(w http.ResponseWriter, r *http.Request) {
|
|
md := prMetadata{
|
|
// The protected resource is the MCP endpoint that ships in phase 5.
|
|
// Publishing it here lets clients discover where to send Bearer
|
|
// tokens once they've completed the OAuth dance.
|
|
Resource: s.issuer + "/mcp",
|
|
AuthorizationServers: []string{s.issuer},
|
|
BearerMethodsSupported: []string{"header"},
|
|
}
|
|
if s.scopes != "" {
|
|
md.ScopesSupported = strings.Fields(s.scopes)
|
|
}
|
|
writeJSON(w, http.StatusOK, md)
|
|
}
|
|
|
|
// ============================================================================
|
|
// /oauth/register — RFC 7591 dynamic client registration
|
|
// ============================================================================
|
|
|
|
// registerRequest is the subset of RFC 7591 client metadata that
// POST /oauth/register decodes from the request body. The same value is
// re-encoded and stored as the client's metadata record.
type registerRequest struct {
	// RedirectURIs must contain at least one entry; the handler rejects
	// the request otherwise.
	RedirectURIs []string `json:"redirect_uris"`
	// ClientName is echoed back in the registration response.
	ClientName string `json:"client_name,omitempty"`

	// The remaining fields are decoded and stored, but the handler
	// answers with fixed values rather than these requested ones.
	TokenEndpointAuthMethod string   `json:"token_endpoint_auth_method,omitempty"`
	GrantTypes              []string `json:"grant_types,omitempty"`
	ResponseTypes           []string `json:"response_types,omitempty"`
	Scope                   string   `json:"scope,omitempty"`
}
|
|
|
|
// registerResponse is the JSON body returned by a successful
// POST /oauth/register (RFC 7591 client information response).
type registerResponse struct {
	// ClientID is the newly minted public client identifier.
	ClientID string `json:"client_id"`
	// ClientIDIssuedAt is the issuance time in Unix seconds.
	ClientIDIssuedAt int64 `json:"client_id_issued_at"`
	// RedirectURIs and ClientName echo what the client registered.
	RedirectURIs []string `json:"redirect_uris"`
	ClientName   string   `json:"client_name,omitempty"`
	// The values below describe what the server actually granted.
	TokenEndpointAuthMethod string   `json:"token_endpoint_auth_method"`
	GrantTypes              []string `json:"grant_types"`
	ResponseTypes           []string `json:"response_types"`
}
|
|
|
|
func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) {
|
|
var req registerRequest
|
|
dec := json.NewDecoder(r.Body)
|
|
dec.DisallowUnknownFields()
|
|
if err := dec.Decode(&req); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_client_metadata",
|
|
"could not parse JSON: "+err.Error())
|
|
return
|
|
}
|
|
if len(req.RedirectURIs) == 0 {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri",
|
|
"at least one redirect_uri is required")
|
|
return
|
|
}
|
|
for _, ru := range req.RedirectURIs {
|
|
if err := validateRedirectURI(ru); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri", err.Error())
|
|
return
|
|
}
|
|
}
|
|
|
|
clientID := secureToken(16)
|
|
now := s.now().Unix()
|
|
|
|
uris, _ := json.Marshal(req.RedirectURIs)
|
|
meta, _ := json.Marshal(req)
|
|
|
|
if _, err := s.store.DB().ExecContext(r.Context(),
|
|
`INSERT INTO clients (client_id, redirect_uris, metadata_json, created_at)
|
|
VALUES (?, ?, ?, ?)`,
|
|
clientID, string(uris), string(meta), now,
|
|
); err != nil {
|
|
s.log.Error("register: insert client", slog.String("err", err.Error()))
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "client registration failed")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusCreated, registerResponse{
|
|
ClientID: clientID,
|
|
ClientIDIssuedAt: now,
|
|
RedirectURIs: req.RedirectURIs,
|
|
ClientName: req.ClientName,
|
|
TokenEndpointAuthMethod: "none", // PKCE-only public clients
|
|
GrantTypes: []string{"authorization_code", "refresh_token"},
|
|
ResponseTypes: []string{"code"},
|
|
})
|
|
}
|
|
|
|
// validateRedirectURI checks that raw is acceptable as a registered
// redirect URI and returns a descriptive error when it is not.
//
// Accepted: any absolute URI (one with a scheme), including custom
// application schemes. Rejected: unparsable input, input without a
// scheme, any URI containing a fragment (RFC 6749 §3.1.2 forbids one),
// and http/https URIs without a host.
func validateRedirectURI(raw string) error {
	u, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("redirect_uri %q: %w", raw, err)
	}
	if u.Scheme == "" {
		return fmt.Errorf("redirect_uri %q: missing scheme", raw)
	}
	// Test the raw text rather than u.Fragment so that a bare trailing
	// "#" (an empty fragment) is caught as well.
	if strings.Contains(raw, "#") {
		return fmt.Errorf("redirect_uri %q: must not contain a fragment", raw)
	}
	// url.Parse lower-cases the scheme, so a plain comparison suffices.
	if (u.Scheme == "http" || u.Scheme == "https") && u.Host == "" {
		return fmt.Errorf("redirect_uri %q: missing host", raw)
	}
	return nil
}
|
|
|
|
// ============================================================================
|
|
// /oauth/authorize — RFC 6749 + RFC 7636 PKCE
|
|
// ============================================================================
|
|
|
|
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
|
q := r.URL.Query()
|
|
|
|
clientID := q.Get("client_id")
|
|
redirectURI := q.Get("redirect_uri")
|
|
responseType := q.Get("response_type")
|
|
clientState := q.Get("state")
|
|
codeChallenge := q.Get("code_challenge")
|
|
codeChallengeMethod := q.Get("code_challenge_method")
|
|
scope := q.Get("scope")
|
|
|
|
// Errors that come BEFORE we know a valid redirect_uri: respond
|
|
// directly. RFC 6749 §3.1.2.4 — if redirect_uri is invalid, do not
|
|
// redirect; render an error instead.
|
|
|
|
if clientID == "" {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "client_id is required")
|
|
return
|
|
}
|
|
registered, err := s.lookupClientRedirectURIs(r.Context(), clientID)
|
|
if err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_client", "unknown client_id")
|
|
return
|
|
}
|
|
if redirectURI == "" {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "redirect_uri is required")
|
|
return
|
|
}
|
|
if !redirectURIMatches(registered, redirectURI) {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request",
|
|
"redirect_uri does not match any registered URI")
|
|
return
|
|
}
|
|
|
|
// From here on, errors can be returned via redirect to the client
|
|
// (per RFC 6749 §4.1.2.1). We use that path to surface PKCE/scope
|
|
// problems back to the calling app.
|
|
|
|
if responseType != "code" {
|
|
redirectAuthError(w, r, redirectURI, clientState, "unsupported_response_type",
|
|
"only response_type=code is supported")
|
|
return
|
|
}
|
|
if codeChallenge == "" {
|
|
redirectAuthError(w, r, redirectURI, clientState, "invalid_request",
|
|
"PKCE is required: code_challenge missing")
|
|
return
|
|
}
|
|
if codeChallengeMethod != "S256" {
|
|
redirectAuthError(w, r, redirectURI, clientState, "invalid_request",
|
|
"only code_challenge_method=S256 is supported")
|
|
return
|
|
}
|
|
|
|
// Stash the in-flight authorization. forgejoState is the value we'll
|
|
// pass to Forgejo and read back from /callback.
|
|
forgejoState := secureToken(24)
|
|
s.pending.Store(forgejoState, &pendingAuth{
|
|
clientID: clientID,
|
|
redirectURI: redirectURI,
|
|
codeChallenge: codeChallenge,
|
|
codeChallengeMethod: codeChallengeMethod,
|
|
scopes: scope,
|
|
clientState: clientState,
|
|
expiresAt: s.now().Add(PendingAuthTTL),
|
|
})
|
|
|
|
// Redirect the user-agent to Forgejo. Forgejo asks the user to consent;
|
|
// on success it'll redirect back to our /oauth/callback with code+state.
|
|
upstream := s.forgejo.AuthorizeURL(forgejo.AuthorizeURLOptions{
|
|
RedirectURI: s.issuer + "/oauth/callback",
|
|
State: forgejoState,
|
|
Scopes: s.scopes,
|
|
CodeChallenge: "", // we don't pass our PKCE through; the broker is
|
|
CodeChallengeMethod: "", // a confidential OAuth client of Forgejo.
|
|
})
|
|
http.Redirect(w, r, upstream, http.StatusFound)
|
|
}
|
|
|
|
func (s *Server) lookupClientRedirectURIs(ctx context.Context, clientID string) ([]string, error) {
|
|
var raw string
|
|
row := s.store.DB().QueryRowContext(ctx,
|
|
`SELECT redirect_uris FROM clients WHERE client_id = ?`, clientID)
|
|
if err := row.Scan(&raw); err != nil {
|
|
return nil, err
|
|
}
|
|
var uris []string
|
|
if err := json.Unmarshal([]byte(raw), &uris); err != nil {
|
|
return nil, err
|
|
}
|
|
return uris, nil
|
|
}
|
|
|
|
// redirectURIMatches reports whether candidate equals one of the
// registered URIs. The comparison is exact string equality — no
// normalisation, prefix or pattern matching.
func redirectURIMatches(registered []string, candidate string) bool {
	for i := 0; i < len(registered); i++ {
		if registered[i] != candidate {
			continue
		}
		return true
	}
	return false
}
|
|
|
|
// redirectAuthError sends the user-agent back to the client's redirect_uri
|
|
// with error=... and state=... in the query string, per RFC 6749 §4.1.2.1.
|
|
func redirectAuthError(w http.ResponseWriter, r *http.Request, redirectURI, state, code, description string) {
|
|
u, err := url.Parse(redirectURI)
|
|
if err != nil {
|
|
// Should be unreachable since redirectURIMatches already verified
|
|
// it parses — but be safe.
|
|
writeOAuthError(w, http.StatusBadRequest, code, description)
|
|
return
|
|
}
|
|
q := u.Query()
|
|
q.Set("error", code)
|
|
if description != "" {
|
|
q.Set("error_description", description)
|
|
}
|
|
if state != "" {
|
|
q.Set("state", state)
|
|
}
|
|
u.RawQuery = q.Encode()
|
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
|
}
|
|
|
|
// ============================================================================
|
|
// /oauth/callback — Forgejo redirects here after user consent
|
|
// ============================================================================
|
|
|
|
func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
q := r.URL.Query()
|
|
state := q.Get("state")
|
|
upstreamCode := q.Get("code")
|
|
upstreamErr := q.Get("error")
|
|
|
|
v, ok := s.pending.LoadAndDelete(state)
|
|
if !ok {
|
|
http.Error(w, "unknown or expired state; please re-authorize", http.StatusBadRequest)
|
|
return
|
|
}
|
|
pa := v.(*pendingAuth)
|
|
if s.now().After(pa.expiresAt) {
|
|
http.Error(w, "authorization expired; please re-authorize", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if upstreamErr != "" {
|
|
desc := q.Get("error_description")
|
|
redirectAuthError(w, r, pa.redirectURI, pa.clientState, upstreamErr, desc)
|
|
return
|
|
}
|
|
if upstreamCode == "" {
|
|
redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error",
|
|
"upstream returned no code")
|
|
return
|
|
}
|
|
|
|
tok, err := s.forgejo.ExchangeCode(r.Context(), upstreamCode, "" /* no PKCE w/ Forgejo */, s.issuer+"/oauth/callback")
|
|
if err != nil {
|
|
s.log.Error("callback: forgejo exchange", slog.String("err", err.Error()))
|
|
redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error",
|
|
"upstream code exchange failed")
|
|
return
|
|
}
|
|
|
|
ui, err := s.forgejo.FetchUserInfo(r.Context(), tok.AccessToken)
|
|
if err != nil {
|
|
s.log.Error("callback: fetch userinfo", slog.String("err", err.Error()))
|
|
redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error",
|
|
"upstream userinfo failed")
|
|
return
|
|
}
|
|
|
|
brokerCode := secureToken(32)
|
|
now := s.now().Unix()
|
|
|
|
if _, err := s.store.DB().ExecContext(r.Context(),
|
|
`INSERT INTO auth_codes
|
|
(code, client_id, redirect_uri, code_challenge, code_challenge_method,
|
|
scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
|
|
forgejo_user_id, forgejo_username, expires_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
brokerCode, pa.clientID, pa.redirectURI, pa.codeChallenge, pa.codeChallengeMethod,
|
|
pa.scopes, tok.AccessToken, tok.RefreshToken,
|
|
s.now().Add(time.Duration(tok.ExpiresIn)*time.Second).Unix(),
|
|
userIDInt64(ui.Sub), ui.PreferredUsername,
|
|
s.now().Add(AuthCodeTTL).Unix(),
|
|
); err != nil {
|
|
s.log.Error("callback: insert auth_code", slog.String("err", err.Error()))
|
|
redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error",
|
|
"failed to persist authorization code")
|
|
return
|
|
}
|
|
_ = now
|
|
|
|
// Redirect back to the client with our code.
|
|
u, _ := url.Parse(pa.redirectURI)
|
|
rq := u.Query()
|
|
rq.Set("code", brokerCode)
|
|
if pa.clientState != "" {
|
|
rq.Set("state", pa.clientState)
|
|
}
|
|
u.RawQuery = rq.Encode()
|
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
|
}
|
|
|
|
// userIDInt64 best-effort converts a string user-id (Forgejo OIDC `sub`)
// to an int64. It returns 0 when sub is empty, contains anything other
// than the ASCII digits 0-9, or names a number larger than the maximum
// int64 — the username column is the reliable identity carrier; user_id
// is for log correlation.
func userIDInt64(sub string) int64 {
	const maxInt64 = 1<<63 - 1

	var n int64
	for _, c := range sub {
		if c < '0' || c > '9' {
			return 0
		}
		d := int64(c - '0')
		// n*10+d would wrap past maxInt64 and yield a bogus (possibly
		// negative) id; treat that like any other unparsable input.
		if n > (maxInt64-d)/10 {
			return 0
		}
		n = n*10 + d
	}
	return n
}
|
|
|
|
// ============================================================================
|
|
// /oauth/token — authorization_code and refresh_token grants
|
|
// ============================================================================
|
|
|
|
// tokenResponse is the JSON body of a successful POST /oauth/token.
type tokenResponse struct {
	// AccessToken is the plaintext broker access token.
	AccessToken string `json:"access_token"`
	// TokenType is the token type, e.g. "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresIn is the access token lifetime in seconds.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the plaintext refresh token; omitted when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Scope is the granted scope string; omitted when empty.
	Scope string `json:"scope,omitempty"`
}
|
|
|
|
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form")
|
|
return
|
|
}
|
|
switch r.PostForm.Get("grant_type") {
|
|
case "authorization_code":
|
|
s.tokenAuthCodeGrant(w, r)
|
|
case "refresh_token":
|
|
s.tokenRefreshGrant(w, r)
|
|
default:
|
|
writeOAuthError(w, http.StatusBadRequest, "unsupported_grant_type",
|
|
"only authorization_code and refresh_token are supported")
|
|
}
|
|
}
|
|
|
|
func (s *Server) tokenAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
|
|
code := r.PostForm.Get("code")
|
|
clientID := r.PostForm.Get("client_id")
|
|
redirectURI := r.PostForm.Get("redirect_uri")
|
|
codeVerifier := r.PostForm.Get("code_verifier")
|
|
|
|
if code == "" || clientID == "" || redirectURI == "" || codeVerifier == "" {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request",
|
|
"code, client_id, redirect_uri, and code_verifier are required")
|
|
return
|
|
}
|
|
|
|
// Look up the auth code and lock it via UPDATE ... used_at single-shot.
|
|
row := s.store.DB().QueryRowContext(r.Context(),
|
|
`SELECT client_id, redirect_uri, code_challenge, code_challenge_method,
|
|
scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
|
|
forgejo_user_id, forgejo_username, expires_at, used_at
|
|
FROM auth_codes WHERE code = ?`, code)
|
|
var (
|
|
storedClientID, storedRedirectURI, storedChallenge, storedMethod, storedScopes string
|
|
fjAccess, fjRefresh, fjUsername string
|
|
fjUserID int64
|
|
fjExpiresAt, expiresAt int64
|
|
usedAt sql.NullInt64
|
|
)
|
|
if err := row.Scan(&storedClientID, &storedRedirectURI, &storedChallenge, &storedMethod,
|
|
&storedScopes, &fjAccess, &fjRefresh, &fjExpiresAt, &fjUserID, &fjUsername,
|
|
&expiresAt, &usedAt); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code not found")
|
|
return
|
|
}
|
|
if usedAt.Valid {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used")
|
|
return
|
|
}
|
|
if s.now().Unix() > expiresAt {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code expired")
|
|
return
|
|
}
|
|
if storedClientID != clientID {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch")
|
|
return
|
|
}
|
|
if storedRedirectURI != redirectURI {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch")
|
|
return
|
|
}
|
|
if storedMethod != "S256" || !verifyPKCE(codeVerifier, storedChallenge) {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verification failed")
|
|
return
|
|
}
|
|
|
|
// Atomically mark the code used. The WHERE clause re-checks used_at IS
|
|
// NULL so two concurrent /token requests with the same code can't both
|
|
// succeed (only one UPDATE will affect a row).
|
|
res, err := s.store.DB().ExecContext(r.Context(),
|
|
`UPDATE auth_codes SET used_at = ? WHERE code = ? AND used_at IS NULL`,
|
|
s.now().Unix(), code)
|
|
if err != nil {
|
|
s.log.Error("token: mark code used", slog.String("err", err.Error()))
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if n, _ := res.RowsAffected(); n != 1 {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used")
|
|
return
|
|
}
|
|
|
|
// Mint broker access + refresh tokens.
|
|
accessToken := secureToken(32)
|
|
refreshToken := secureToken(32)
|
|
now := s.now().Unix()
|
|
|
|
tx, err := s.store.DB().BeginTx(r.Context(), nil)
|
|
if err != nil {
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if _, err := tx.ExecContext(r.Context(),
|
|
`INSERT INTO access_tokens
|
|
(token_hash, client_id, forgejo_user_id, forgejo_username, scopes,
|
|
forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
|
|
expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
hashToken(accessToken), clientID, fjUserID, fjUsername, storedScopes,
|
|
fjAccess, fjRefresh, fjExpiresAt,
|
|
s.now().Add(AccessTokenTTL).Unix(), now); err != nil {
|
|
_ = tx.Rollback()
|
|
s.log.Error("token: insert access_token", slog.String("err", err.Error()))
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if _, err := tx.ExecContext(r.Context(),
|
|
`INSERT INTO refresh_tokens
|
|
(token_hash, access_token_hash, client_id, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
hashToken(refreshToken), hashToken(accessToken), clientID,
|
|
s.now().Add(RefreshTokenTTL).Unix(), now); err != nil {
|
|
_ = tx.Rollback()
|
|
s.log.Error("token: insert refresh_token", slog.String("err", err.Error()))
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, tokenResponse{
|
|
AccessToken: accessToken,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int(AccessTokenTTL.Seconds()),
|
|
RefreshToken: refreshToken,
|
|
Scope: storedScopes,
|
|
})
|
|
}
|
|
|
|
func (s *Server) tokenRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
|
refreshToken := r.PostForm.Get("refresh_token")
|
|
clientID := r.PostForm.Get("client_id")
|
|
if refreshToken == "" || clientID == "" {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request",
|
|
"refresh_token and client_id are required")
|
|
return
|
|
}
|
|
|
|
rtHash := hashToken(refreshToken)
|
|
row := s.store.DB().QueryRowContext(r.Context(),
|
|
`SELECT rt.access_token_hash, rt.client_id, rt.expires_at, rt.revoked_at,
|
|
at.forgejo_user_id, at.forgejo_username, at.scopes,
|
|
at.forgejo_access_token, at.forgejo_refresh_token, at.forgejo_token_expires_at
|
|
FROM refresh_tokens rt
|
|
JOIN access_tokens at ON at.token_hash = rt.access_token_hash
|
|
WHERE rt.token_hash = ?`, rtHash)
|
|
var (
|
|
oldAccessHash, storedClientID, fjUsername, scopes string
|
|
fjAccess, fjRefresh string
|
|
fjUserID int64
|
|
expiresAt, fjExpiresAt int64
|
|
revokedAt sql.NullInt64
|
|
)
|
|
if err := row.Scan(&oldAccessHash, &storedClientID, &expiresAt, &revokedAt,
|
|
&fjUserID, &fjUsername, &scopes, &fjAccess, &fjRefresh, &fjExpiresAt); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token not found")
|
|
return
|
|
}
|
|
if revokedAt.Valid {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token revoked")
|
|
return
|
|
}
|
|
if s.now().Unix() > expiresAt {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token expired")
|
|
return
|
|
}
|
|
if storedClientID != clientID {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch")
|
|
return
|
|
}
|
|
|
|
// Mint a new access token. Refresh-token rotation: also issue a new
|
|
// refresh token and revoke the old one.
|
|
newAccess := secureToken(32)
|
|
newRefresh := secureToken(32)
|
|
now := s.now().Unix()
|
|
|
|
tx, err := s.store.DB().BeginTx(r.Context(), nil)
|
|
if err != nil {
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if _, err := tx.ExecContext(r.Context(),
|
|
`INSERT INTO access_tokens
|
|
(token_hash, client_id, forgejo_user_id, forgejo_username, scopes,
|
|
forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
|
|
expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
hashToken(newAccess), clientID, fjUserID, fjUsername, scopes,
|
|
fjAccess, fjRefresh, fjExpiresAt,
|
|
s.now().Add(AccessTokenTTL).Unix(), now); err != nil {
|
|
_ = tx.Rollback()
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if _, err := tx.ExecContext(r.Context(),
|
|
`INSERT INTO refresh_tokens
|
|
(token_hash, access_token_hash, client_id, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
hashToken(newRefresh), hashToken(newAccess), clientID,
|
|
s.now().Add(RefreshTokenTTL).Unix(), now); err != nil {
|
|
_ = tx.Rollback()
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
// Revoke the old refresh token (rotation per RFC 6749 §10.4).
|
|
if _, err := tx.ExecContext(r.Context(),
|
|
`UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ?`,
|
|
now, rtHash); err != nil {
|
|
_ = tx.Rollback()
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, tokenResponse{
|
|
AccessToken: newAccess,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int(AccessTokenTTL.Seconds()),
|
|
RefreshToken: newRefresh,
|
|
Scope: scopes,
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// /oauth/revoke — RFC 7009
|
|
// ============================================================================
|
|
|
|
func (s *Server) handleRevoke(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form")
|
|
return
|
|
}
|
|
tokenStr := r.PostForm.Get("token")
|
|
hint := r.PostForm.Get("token_type_hint") // "access_token" or "refresh_token"
|
|
if tokenStr == "" {
|
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
|
return
|
|
}
|
|
|
|
hash := hashToken(tokenStr)
|
|
now := s.now().Unix()
|
|
|
|
// Try the hinted table first; fall back to the other. RFC 7009 says
|
|
// invalid tokens still get a 200 — clients shouldn't probe.
|
|
if hint != "refresh_token" {
|
|
_, _ = s.store.DB().ExecContext(r.Context(),
|
|
`UPDATE access_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`,
|
|
now, hash)
|
|
}
|
|
if hint != "access_token" {
|
|
_, _ = s.store.DB().ExecContext(r.Context(),
|
|
`UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`,
|
|
now, hash)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|