// Package oauth implements the broker's OAuth 2.1 authorization server
// surface — what Claude.ai (and other MCP clients) talk to. User auth is
// delegated to upstream Forgejo via internal/forgejo.
//
// Endpoints (RFC numbers in parentheses):
//
//	POST /oauth/register  — RFC 7591 dynamic client registration
//	GET  /oauth/authorize — RFC 6749 / 7636 authorize with PKCE S256
//	GET  /oauth/callback  — Forgejo redirects back here after user consent
//	POST /oauth/token     — authorization_code and refresh_token grants
//	POST /oauth/revoke    — RFC 7009 token revocation
//
// The same handler also serves the discovery documents at
// /.well-known/oauth-authorization-server (RFC 8414) and
// /.well-known/oauth-protected-resource (RFC 9728).
//
// PKCE: required, S256 only. Plain method is rejected (this is OAuth 2.1).
//
// Token storage: every broker-issued access/refresh token is stored as a
// hex-encoded SHA-256 hash. The plaintext leaves the broker exactly once —
// in the body of the /oauth/token response. Broker authorization codes and
// the upstream Forgejo access/refresh tokens, by contrast, are written to
// the database as-is (unhashed), so the database itself must be protected.
//
// Pending authorizations (between /authorize and /callback) live in memory
// with a short TTL (PendingAuthTTL) and are reaped once a minute. A broker
// restart drops them, forcing the user to re-authorize; that's an
// acceptable UX hit for not adding another table.
package oauth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo"
	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
)

// AuthCodeTTL bounds how long a broker-issued authorization code stays
// usable. RFC 6749 §4.1.2 RECOMMENDS a maximum lifetime of 10 minutes (and
// §10.5 requires codes to be short-lived and single-use).
const AuthCodeTTL = 10 * time.Minute

// PendingAuthTTL caps how long an in-flight /authorize → /callback flow can
// take. Anything longer is almost certainly an abandoned attempt.
const PendingAuthTTL = 10 * time.Minute

// AccessTokenTTL is the lifetime of broker access tokens. Refresh tokens
// last considerably longer (see RefreshTokenTTL).
const AccessTokenTTL = 1 * time.Hour // RefreshTokenTTL is the lifetime of broker refresh tokens. 30 days lets a // daily-active user stay logged in indefinitely while bounded enough that // theft via stale backup is time-limited. const RefreshTokenTTL = 30 * 24 * time.Hour // Server is the OAuth authorization server. Construct one with NewServer // and mount its Handler under the broker's HTTP mux. type Server struct { store *store.Store forgejo *forgejo.Client issuer string // public URL, e.g. https://mcp.example.com — never derived from headers scopes string // space-separated scope set requested from upstream Forgejo now func() time.Time log *slog.Logger mu sync.Mutex // guards pendingAuths cleanup; map ops use sync.Map natively pending sync.Map // forgejoState string → *pendingAuth stopCh chan struct{} } // Config bundles Server dependencies. type Config struct { Store *store.Store Forgejo *forgejo.Client Issuer string // required; e.g. https://mcp.example.com Scopes string // optional; space-separated; defaults to "" Now func() time.Time Log *slog.Logger } // NewServer validates the config and starts the periodic pending-auth // reaper. Stop the reaper with Server.Close. func NewServer(cfg Config) (*Server, error) { if cfg.Store == nil { return nil, errors.New("oauth: Store is required") } if cfg.Forgejo == nil { return nil, errors.New("oauth: Forgejo client is required") } if cfg.Issuer == "" { return nil, errors.New("oauth: Issuer is required") } now := cfg.Now if now == nil { now = time.Now } logger := cfg.Log if logger == nil { logger = slog.New(slog.DiscardHandler) } s := &Server{ store: cfg.Store, forgejo: cfg.Forgejo, issuer: strings.TrimRight(cfg.Issuer, "/"), scopes: cfg.Scopes, now: now, log: logger, stopCh: make(chan struct{}), } go s.reapPendingLoop() return s, nil } // Close stops the pending-auth reaper. Safe to call multiple times. 
func (s *Server) Close() { select { case <-s.stopCh: // already closed default: close(s.stopCh) } } // Handler returns the http.Handler exposing OAuth endpoints plus the two // discovery documents. The broker's outer mux should mount this at the // root (not under /oauth) so the .well-known paths land at their spec- // mandated location. func (s *Server) Handler() http.Handler { mux := http.NewServeMux() mux.HandleFunc("POST /oauth/register", s.handleRegister) mux.HandleFunc("GET /oauth/authorize", s.handleAuthorize) mux.HandleFunc("GET /oauth/callback", s.handleCallback) mux.HandleFunc("POST /oauth/token", s.handleToken) mux.HandleFunc("POST /oauth/revoke", s.handleRevoke) mux.HandleFunc("GET /.well-known/oauth-authorization-server", s.handleASMetadata) mux.HandleFunc("GET /.well-known/oauth-protected-resource", s.handlePRMetadata) return mux } // pendingAuth is the in-memory record of a /authorize → /callback flow. type pendingAuth struct { clientID string redirectURI string codeChallenge string codeChallengeMethod string scopes string clientState string expiresAt time.Time } // ============================================================================ // Helpers — token generation, hashing, OAuth error responses // ============================================================================ // secureToken returns a hex-encoded cryptographically-random string of the // given byte length. 32 bytes ⇒ 256 bits of entropy ⇒ 64 hex chars. func secureToken(nBytes int) string { b := make([]byte, nBytes) if _, err := rand.Read(b); err != nil { // crypto/rand failing is a system-level emergency. Panicking here is // the right move — operating without entropy is worse than crashing. panic("oauth: crypto/rand failed: " + err.Error()) } return hex.EncodeToString(b) } // hashToken returns the hex-encoded SHA-256 of the given token. Used at the // store boundary so plaintext tokens never persist. 
func hashToken(token string) string { sum := sha256.Sum256([]byte(token)) return hex.EncodeToString(sum[:]) } // verifyPKCE returns true iff base64url(sha256(verifier)) equals challenge. // Constant-time comparison prevents timing leaks on the verifier. func verifyPKCE(verifier, challenge string) bool { sum := sha256.Sum256([]byte(verifier)) got := base64.RawURLEncoding.EncodeToString(sum[:]) return subtle.ConstantTimeCompare([]byte(got), []byte(challenge)) == 1 } // writeJSON writes obj as JSON with the given status. Errors are logged but // not surfaced — by the time encoding could fail, headers are already out. func writeJSON(w http.ResponseWriter, status int, obj any) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Cache-Control", "no-store") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(obj) } // writeOAuthError renders an RFC 6749 §5.2 error response. func writeOAuthError(w http.ResponseWriter, status int, code, description string) { writeJSON(w, status, map[string]string{ "error": code, "error_description": description, }) } // ============================================================================ // Pending-auth reaper // ============================================================================ func (s *Server) reapPendingLoop() { t := time.NewTicker(time.Minute) defer t.Stop() for { select { case <-s.stopCh: return case <-t.C: s.reapPending() } } } func (s *Server) reapPending() { now := s.now() s.pending.Range(func(k, v any) bool { if pa, ok := v.(*pendingAuth); ok && now.After(pa.expiresAt) { s.pending.Delete(k) } return true }) } // ============================================================================ // /.well-known/* — discovery metadata (RFC 8414, RFC 9728) // ============================================================================ // // All URLs in these documents are built from the configured issuer, never // from inbound request headers. 
Publishing an attacker-controlled issuer // is a classic OAuth metadata-spoofing vector — defending against it has // to start at the discovery layer. type asMetadata struct { Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` RegistrationEndpoint string `json:"registration_endpoint"` RevocationEndpoint string `json:"revocation_endpoint,omitempty"` ResponseTypesSupported []string `json:"response_types_supported"` GrantTypesSupported []string `json:"grant_types_supported"` CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` ScopesSupported []string `json:"scopes_supported,omitempty"` } func (s *Server) handleASMetadata(w http.ResponseWriter, r *http.Request) { md := asMetadata{ Issuer: s.issuer, AuthorizationEndpoint: s.issuer + "/oauth/authorize", TokenEndpoint: s.issuer + "/oauth/token", RegistrationEndpoint: s.issuer + "/oauth/register", RevocationEndpoint: s.issuer + "/oauth/revoke", ResponseTypesSupported: []string{"code"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"}, CodeChallengeMethodsSupported: []string{"S256"}, TokenEndpointAuthMethodsSupported: []string{"none"}, } if s.scopes != "" { md.ScopesSupported = strings.Fields(s.scopes) } writeJSON(w, http.StatusOK, md) } type prMetadata struct { Resource string `json:"resource"` AuthorizationServers []string `json:"authorization_servers"` BearerMethodsSupported []string `json:"bearer_methods_supported"` ScopesSupported []string `json:"scopes_supported,omitempty"` } func (s *Server) handlePRMetadata(w http.ResponseWriter, r *http.Request) { md := prMetadata{ // The protected resource is the MCP endpoint that ships in phase 5. // Publishing it here lets clients discover where to send Bearer // tokens once they've completed the OAuth dance. 
Resource: s.issuer + "/mcp", AuthorizationServers: []string{s.issuer}, BearerMethodsSupported: []string{"header"}, } if s.scopes != "" { md.ScopesSupported = strings.Fields(s.scopes) } writeJSON(w, http.StatusOK, md) } // ============================================================================ // /oauth/register — RFC 7591 dynamic client registration // ============================================================================ type registerRequest struct { RedirectURIs []string `json:"redirect_uris"` ClientName string `json:"client_name,omitempty"` TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` GrantTypes []string `json:"grant_types,omitempty"` ResponseTypes []string `json:"response_types,omitempty"` Scope string `json:"scope,omitempty"` } type registerResponse struct { ClientID string `json:"client_id"` ClientIDIssuedAt int64 `json:"client_id_issued_at"` RedirectURIs []string `json:"redirect_uris"` ClientName string `json:"client_name,omitempty"` TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` GrantTypes []string `json:"grant_types"` ResponseTypes []string `json:"response_types"` } func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { var req registerRequest dec := json.NewDecoder(r.Body) dec.DisallowUnknownFields() if err := dec.Decode(&req); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_client_metadata", "could not parse JSON: "+err.Error()) return } if len(req.RedirectURIs) == 0 { writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri", "at least one redirect_uri is required") return } for _, ru := range req.RedirectURIs { if err := validateRedirectURI(ru); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri", err.Error()) return } } clientID := secureToken(16) now := s.now().Unix() uris, _ := json.Marshal(req.RedirectURIs) meta, _ := json.Marshal(req) if _, err := s.store.DB().ExecContext(r.Context(), `INSERT INTO clients (client_id, 
redirect_uris, metadata_json, created_at) VALUES (?, ?, ?, ?)`, clientID, string(uris), string(meta), now, ); err != nil { s.log.Error("register: insert client", slog.String("err", err.Error())) writeOAuthError(w, http.StatusInternalServerError, "server_error", "client registration failed") return } writeJSON(w, http.StatusCreated, registerResponse{ ClientID: clientID, ClientIDIssuedAt: now, RedirectURIs: req.RedirectURIs, ClientName: req.ClientName, TokenEndpointAuthMethod: "none", // PKCE-only public clients GrantTypes: []string{"authorization_code", "refresh_token"}, ResponseTypes: []string{"code"}, }) } func validateRedirectURI(raw string) error { u, err := url.Parse(raw) if err != nil { return fmt.Errorf("redirect_uri %q: %w", raw, err) } // RFC 6749 §3.1.2.1 requires absolute URIs. We further restrict to // schemes we trust: // - https: the production case // - http: only for loopback hosts (local development) // - private-use schemes per RFC 8252 §7.1 (e.g. claude://, com.foo://) // Pseudo-schemes that allow code execution (javascript:, data:) are // rejected to keep a future naive client from rendering an attacker- // supplied URI as content. scheme := u.Scheme switch { case scheme == "https": return nil case scheme == "http": host := u.Hostname() if host == "localhost" || host == "127.0.0.1" || host == "::1" { return nil } return fmt.Errorf("redirect_uri %q: http only allowed for loopback hosts", raw) case scheme == "javascript" || scheme == "data" || scheme == "": return fmt.Errorf("redirect_uri %q: scheme %q is not allowed", raw, scheme) default: // Anything else: must be a private-use URI scheme that contains a // dot (e.g. com.example.app:/) per RFC 8252 §7.1. Single-word // schemes like "javascript" are caught above; this keeps the door // open for legitimate mobile/desktop OAuth flows without a // hardcoded allowlist. 
if !strings.Contains(scheme, ".") { return fmt.Errorf("redirect_uri %q: non-https scheme %q must be a reverse-DNS private-use scheme", raw, scheme) } return nil } } // ============================================================================ // /oauth/authorize — RFC 6749 + RFC 7636 PKCE // ============================================================================ func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() clientID := q.Get("client_id") redirectURI := q.Get("redirect_uri") responseType := q.Get("response_type") clientState := q.Get("state") codeChallenge := q.Get("code_challenge") codeChallengeMethod := q.Get("code_challenge_method") scope := q.Get("scope") // Errors that come BEFORE we know a valid redirect_uri: respond // directly. RFC 6749 §3.1.2.4 — if redirect_uri is invalid, do not // redirect; render an error instead. if clientID == "" { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "client_id is required") return } registered, err := s.lookupClientRedirectURIs(r.Context(), clientID) if err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_client", "unknown client_id") return } if redirectURI == "" { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "redirect_uri is required") return } if !redirectURIMatches(registered, redirectURI) { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "redirect_uri does not match any registered URI") return } // From here on, errors can be returned via redirect to the client // (per RFC 6749 §4.1.2.1). We use that path to surface PKCE/scope // problems back to the calling app. 
if responseType != "code" { redirectAuthError(w, r, redirectURI, clientState, "unsupported_response_type", "only response_type=code is supported") return } if codeChallenge == "" { redirectAuthError(w, r, redirectURI, clientState, "invalid_request", "PKCE is required: code_challenge missing") return } if codeChallengeMethod != "S256" { redirectAuthError(w, r, redirectURI, clientState, "invalid_request", "only code_challenge_method=S256 is supported") return } // Stash the in-flight authorization. forgejoState is the value we'll // pass to Forgejo and read back from /callback. forgejoState := secureToken(24) s.pending.Store(forgejoState, &pendingAuth{ clientID: clientID, redirectURI: redirectURI, codeChallenge: codeChallenge, codeChallengeMethod: codeChallengeMethod, scopes: scope, clientState: clientState, expiresAt: s.now().Add(PendingAuthTTL), }) // Redirect the user-agent to Forgejo. Forgejo asks the user to consent; // on success it'll redirect back to our /oauth/callback with code+state. upstream := s.forgejo.AuthorizeURL(forgejo.AuthorizeURLOptions{ RedirectURI: s.issuer + "/oauth/callback", State: forgejoState, Scopes: s.scopes, CodeChallenge: "", // we don't pass our PKCE through; the broker is CodeChallengeMethod: "", // a confidential OAuth client of Forgejo. }) http.Redirect(w, r, upstream, http.StatusFound) } func (s *Server) lookupClientRedirectURIs(ctx context.Context, clientID string) ([]string, error) { var raw string row := s.store.DB().QueryRowContext(ctx, `SELECT redirect_uris FROM clients WHERE client_id = ?`, clientID) if err := row.Scan(&raw); err != nil { return nil, err } var uris []string if err := json.Unmarshal([]byte(raw), &uris); err != nil { return nil, err } return uris, nil } func redirectURIMatches(registered []string, candidate string) bool { for _, r := range registered { if r == candidate { return true } } return false } // redirectAuthError sends the user-agent back to the client's redirect_uri // with error=... and state=... 
in the query string, per RFC 6749 §4.1.2.1. func redirectAuthError(w http.ResponseWriter, r *http.Request, redirectURI, state, code, description string) { u, err := url.Parse(redirectURI) if err != nil { // Should be unreachable since redirectURIMatches already verified // it parses — but be safe. writeOAuthError(w, http.StatusBadRequest, code, description) return } q := u.Query() q.Set("error", code) if description != "" { q.Set("error_description", description) } if state != "" { q.Set("state", state) } u.RawQuery = q.Encode() http.Redirect(w, r, u.String(), http.StatusFound) } // ============================================================================ // /oauth/callback — Forgejo redirects here after user consent // ============================================================================ func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() state := q.Get("state") upstreamCode := q.Get("code") upstreamErr := q.Get("error") v, ok := s.pending.LoadAndDelete(state) if !ok { http.Error(w, "unknown or expired state; please re-authorize", http.StatusBadRequest) return } pa := v.(*pendingAuth) if s.now().After(pa.expiresAt) { http.Error(w, "authorization expired; please re-authorize", http.StatusBadRequest) return } if upstreamErr != "" { desc := q.Get("error_description") redirectAuthError(w, r, pa.redirectURI, pa.clientState, upstreamErr, desc) return } if upstreamCode == "" { redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", "upstream returned no code") return } tok, err := s.forgejo.ExchangeCode(r.Context(), upstreamCode, "" /* no PKCE w/ Forgejo */, s.issuer+"/oauth/callback") if err != nil { s.log.Error("callback: forgejo exchange", slog.String("err", err.Error())) redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", "upstream code exchange failed") return } ui, err := s.forgejo.FetchUserInfo(r.Context(), tok.AccessToken) if err != nil { s.log.Error("callback: fetch userinfo", 
slog.String("err", err.Error())) redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", "upstream userinfo failed") return } brokerCode := secureToken(32) now := s.now().Unix() if _, err := s.store.DB().ExecContext(r.Context(), `INSERT INTO auth_codes (code, client_id, redirect_uri, code_challenge, code_challenge_method, scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, forgejo_user_id, forgejo_username, expires_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, brokerCode, pa.clientID, pa.redirectURI, pa.codeChallenge, pa.codeChallengeMethod, pa.scopes, tok.AccessToken, tok.RefreshToken, s.now().Add(time.Duration(tok.ExpiresIn)*time.Second).Unix(), userIDInt64(ui.Sub), ui.PreferredUsername, s.now().Add(AuthCodeTTL).Unix(), ); err != nil { s.log.Error("callback: insert auth_code", slog.String("err", err.Error())) redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", "failed to persist authorization code") return } _ = now // Redirect back to the client with our code. u, _ := url.Parse(pa.redirectURI) rq := u.Query() rq.Set("code", brokerCode) if pa.clientState != "" { rq.Set("state", pa.clientState) } u.RawQuery = rq.Encode() http.Redirect(w, r, u.String(), http.StatusFound) } // userIDInt64 best-effort converts a string user-id (Forgejo OIDC `sub`) to // an int64. Returns 0 if it can't parse — the username column is the // reliable identity carrier; user_id is for log correlation. 
func userIDInt64(sub string) int64 { var n int64 for _, c := range sub { if c < '0' || c > '9' { return 0 } n = n*10 + int64(c-'0') } return n } // ============================================================================ // /oauth/token — authorization_code and refresh_token grants // ============================================================================ type tokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope,omitempty"` } func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form") return } switch r.PostForm.Get("grant_type") { case "authorization_code": s.tokenAuthCodeGrant(w, r) case "refresh_token": s.tokenRefreshGrant(w, r) default: writeOAuthError(w, http.StatusBadRequest, "unsupported_grant_type", "only authorization_code and refresh_token are supported") } } func (s *Server) tokenAuthCodeGrant(w http.ResponseWriter, r *http.Request) { code := r.PostForm.Get("code") clientID := r.PostForm.Get("client_id") redirectURI := r.PostForm.Get("redirect_uri") codeVerifier := r.PostForm.Get("code_verifier") if code == "" || clientID == "" || redirectURI == "" || codeVerifier == "" { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "code, client_id, redirect_uri, and code_verifier are required") return } // Look up the auth code and lock it via UPDATE ... used_at single-shot. 
row := s.store.DB().QueryRowContext(r.Context(), `SELECT client_id, redirect_uri, code_challenge, code_challenge_method, scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, forgejo_user_id, forgejo_username, expires_at, used_at FROM auth_codes WHERE code = ?`, code) var ( storedClientID, storedRedirectURI, storedChallenge, storedMethod, storedScopes string fjAccess, fjRefresh, fjUsername string fjUserID int64 fjExpiresAt, expiresAt int64 usedAt sql.NullInt64 ) if err := row.Scan(&storedClientID, &storedRedirectURI, &storedChallenge, &storedMethod, &storedScopes, &fjAccess, &fjRefresh, &fjExpiresAt, &fjUserID, &fjUsername, &expiresAt, &usedAt); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code not found") return } if usedAt.Valid { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used") return } if s.now().Unix() > expiresAt { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code expired") return } if storedClientID != clientID { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch") return } if storedRedirectURI != redirectURI { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") return } if storedMethod != "S256" || !verifyPKCE(codeVerifier, storedChallenge) { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verification failed") return } // Atomically mark the code used. The WHERE clause re-checks used_at IS // NULL so two concurrent /token requests with the same code can't both // succeed (only one UPDATE will affect a row). res, err := s.store.DB().ExecContext(r.Context(), `UPDATE auth_codes SET used_at = ? WHERE code = ? 
AND used_at IS NULL`, s.now().Unix(), code) if err != nil { s.log.Error("token: mark code used", slog.String("err", err.Error())) writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if n, _ := res.RowsAffected(); n != 1 { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used") return } // Mint broker access + refresh tokens. accessToken := secureToken(32) refreshToken := secureToken(32) now := s.now().Unix() tx, err := s.store.DB().BeginTx(r.Context(), nil) if err != nil { writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if _, err := tx.ExecContext(r.Context(), `INSERT INTO access_tokens (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, expires_at, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, hashToken(accessToken), clientID, fjUserID, fjUsername, storedScopes, fjAccess, fjRefresh, fjExpiresAt, s.now().Add(AccessTokenTTL).Unix(), now); err != nil { _ = tx.Rollback() s.log.Error("token: insert access_token", slog.String("err", err.Error())) writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if _, err := tx.ExecContext(r.Context(), `INSERT INTO refresh_tokens (token_hash, access_token_hash, client_id, expires_at, created_at) VALUES (?, ?, ?, ?, ?)`, hashToken(refreshToken), hashToken(accessToken), clientID, s.now().Add(RefreshTokenTTL).Unix(), now); err != nil { _ = tx.Rollback() s.log.Error("token: insert refresh_token", slog.String("err", err.Error())) writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if err := tx.Commit(); err != nil { writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } writeJSON(w, http.StatusOK, tokenResponse{ AccessToken: accessToken, TokenType: "Bearer", ExpiresIn: int(AccessTokenTTL.Seconds()), RefreshToken: refreshToken, Scope: storedScopes, }) } func (s *Server) 
tokenRefreshGrant(w http.ResponseWriter, r *http.Request) { refreshToken := r.PostForm.Get("refresh_token") clientID := r.PostForm.Get("client_id") if refreshToken == "" || clientID == "" { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "refresh_token and client_id are required") return } rtHash := hashToken(refreshToken) row := s.store.DB().QueryRowContext(r.Context(), `SELECT rt.access_token_hash, rt.client_id, rt.expires_at, rt.revoked_at, at.forgejo_user_id, at.forgejo_username, at.scopes, at.forgejo_access_token, at.forgejo_refresh_token, at.forgejo_token_expires_at FROM refresh_tokens rt JOIN access_tokens at ON at.token_hash = rt.access_token_hash WHERE rt.token_hash = ?`, rtHash) var ( oldAccessHash, storedClientID, fjUsername, scopes string fjAccess, fjRefresh string fjUserID int64 expiresAt, fjExpiresAt int64 revokedAt sql.NullInt64 ) if err := row.Scan(&oldAccessHash, &storedClientID, &expiresAt, &revokedAt, &fjUserID, &fjUsername, &scopes, &fjAccess, &fjRefresh, &fjExpiresAt); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token not found") return } if revokedAt.Valid { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token revoked") return } if s.now().Unix() > expiresAt { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token expired") return } if storedClientID != clientID { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch") return } // Atomically revoke the old refresh token. Two concurrent refresh // requests with the same token would otherwise both pass the read // above and each mint a fresh pair — quota duplication and a hint // to a stolen-refresh attacker that the legitimate user is also // active. Same single-shot pattern as the auth-code grant. now := s.now().Unix() res, err := s.store.DB().ExecContext(r.Context(), `UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ? 
AND revoked_at IS NULL`, now, rtHash) if err != nil { writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if n, _ := res.RowsAffected(); n != 1 { writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token already used") return } newAccess := secureToken(32) newRefresh := secureToken(32) tx, err := s.store.DB().BeginTx(r.Context(), nil) if err != nil { writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if _, err := tx.ExecContext(r.Context(), `INSERT INTO access_tokens (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, expires_at, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, hashToken(newAccess), clientID, fjUserID, fjUsername, scopes, fjAccess, fjRefresh, fjExpiresAt, s.now().Add(AccessTokenTTL).Unix(), now); err != nil { _ = tx.Rollback() writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } if _, err := tx.ExecContext(r.Context(), `INSERT INTO refresh_tokens (token_hash, access_token_hash, client_id, expires_at, created_at) VALUES (?, ?, ?, ?, ?)`, hashToken(newRefresh), hashToken(newAccess), clientID, s.now().Add(RefreshTokenTTL).Unix(), now); err != nil { _ = tx.Rollback() writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } // Old refresh token already revoked above (atomic single-shot). 
if err := tx.Commit(); err != nil { writeOAuthError(w, http.StatusInternalServerError, "server_error", "") return } _ = oldAccessHash // retained for potential future "revoke old access on refresh" tightening writeJSON(w, http.StatusOK, tokenResponse{ AccessToken: newAccess, TokenType: "Bearer", ExpiresIn: int(AccessTokenTTL.Seconds()), RefreshToken: newRefresh, Scope: scopes, }) } // ============================================================================ // /oauth/revoke — RFC 7009 // ============================================================================ func (s *Server) handleRevoke(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form") return } tokenStr := r.PostForm.Get("token") hint := r.PostForm.Get("token_type_hint") // "access_token" or "refresh_token" if tokenStr == "" { writeOAuthError(w, http.StatusBadRequest, "invalid_request", "token is required") return } hash := hashToken(tokenStr) now := s.now().Unix() // Try the hinted table first; fall back to the other. RFC 7009 says // invalid tokens still get a 200 — clients shouldn't probe. if hint != "refresh_token" { _, _ = s.store.DB().ExecContext(r.Context(), `UPDATE access_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`, now, hash) } if hint != "access_token" { _, _ = s.store.DB().ExecContext(r.Context(), `UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`, now, hash) } w.WriteHeader(http.StatusOK) }