From d16b18ea38bcb3863b4817cbb124c227afffb17a Mon Sep 17 00:00:00 2001 From: Ole-Morten Duesund Date: Mon, 27 Apr 2026 17:04:34 +0200 Subject: [PATCH] feat(oauth): authorization-server endpoints (forgejo-mcp-broker-pur) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements internal/oauth, the broker's OAuth 2.1 AS surface that Claude.ai (and other MCP clients) talk to. User authentication is delegated to upstream Forgejo via internal/forgejo. Endpoints: POST /oauth/register — RFC 7591 dynamic client registration GET /oauth/authorize — RFC 6749 + 7636 PKCE (S256 only) GET /oauth/callback — Forgejo redirects back here after consent POST /oauth/token — authorization_code + refresh_token grants POST /oauth/revoke — RFC 7009 Security model: - PKCE required, S256 only — plain method rejected per OAuth 2.1 - Every broker-issued access/refresh token stored as hex(sha256(plain)); plaintext leaves the broker exactly once in the /token response body - Refresh-token rotation: each refresh issues a new token pair and revokes the old refresh (RFC 6749 §10.4) - Auth-code single-use enforced atomically via UPDATE...WHERE used_at IS NULL with rows-affected check, blocking the concurrent-replay race - Issuer URL sourced from cfg.Issuer at construction time, never from inbound headers — prevents host-header injection on /.well-known metadata that ships in 2d - redirect_uri must match a registered URI exactly (no prefix/wildcard) Pending-authorization state (between /authorize and /callback) lives in an in-memory sync.Map with a 1-minute reaper goroutine. A broker restart drops them, forcing the user to re-authorize — acceptable trade-off versus introducing a fifth table. Tests: 81.0% coverage with ~20 cases across happy paths, every required- field error, PKCE failure, code-replay, refresh expiry/revocation, client-id and redirect-uri mismatches, Forgejo-side errors, and the reaper logic itself (internal test). Closes forgejo-mcp-broker-pur. The OAuth keystone is in place; phase 2c unblocks discovery (2d) and security review (2e), and combined with the existing supervisor + bridge it unblocks the session glue work in phase 5. Co-Authored-By: Claude Opus 4.7 (1M context) --- .beads/issues.jsonl | 4 +- internal/oauth/oauth.go | 810 +++++++++++++++++++ internal/oauth/oauth_internal_test.go | 107 +++ internal/oauth/oauth_test.go | 1038 +++++++++++++++++++++++++ 4 files changed, 1957 insertions(+), 2 deletions(-) create mode 100644 internal/oauth/oauth.go create mode 100644 internal/oauth/oauth_internal_test.go create mode 100644 internal/oauth/oauth_test.go diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl index e311318..425856a 100644 --- a/.beads/issues.jsonl +++ b/.beads/issues.jsonl @@ -1,14 +1,14 @@ {"id":"forgejo-mcp-broker-q4x","title":"Phase 5c: idle reaper + Forgejo token rotation + child respawn","description":"Reaper (30s tick) applies idle timeout. Rotation (1-min tick) refreshes Forgejo tokens expiring \u003c2min, SIGTERMs child, respawns on next request (design.md §6). Token revocation tears down sessions.","acceptance_criteria":"Clock-injected tests: idle kill, rotation triggers respawn, revocation tears down sessions. Smoke test: 20 concurrent sessions for 10min with mid-test rotations.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:18Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:18Z","dependencies":[{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-q4x","depends_on_id":"forgejo-mcp-broker-t81","type":"blocks","created_at":"2026-04-24T17:45:31Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-ytw","title":"Phase 5b: bearer-token middleware on /mcp","description":"Middleware reads Authorization: Bearer \u003cmcp_token\u003e, resolves via store, attaches Forgejo access token to request context. 401 for missing/expired/revoked.","acceptance_criteria":"Table-driven tests: missing header, malformed, unknown token, expired, revoked, valid. Valid-token path puts Forgejo token on ctx via typed key.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:18Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:18Z","dependencies":[{"issue_id":"forgejo-mcp-broker-ytw","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:30Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-t81","title":"Phase 5a: session registry + spawn-on-initialize","description":"Map Mcp-Session-Id -\u003e supervisor.Child + user metadata. On first initialize for unknown sid, spawn forgejo-mcp with user's Forgejo token in env, bind to bridge. LastActive bumped per request.","acceptance_criteria":"Tests with fake supervisor + fake bridge cover: spawn-on-initialize, reuse for subsequent messages, unknown-sid returns 410, max-sessions cap enforced.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:17Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:17Z","dependencies":[{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-am1","type":"blocks","created_at":"2026-04-24T17:45:29Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:30Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t81","depends_on_id":"forgejo-mcp-broker-zuq","type":"blocks","created_at":"2026-04-24T17:45:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":3,"dependent_count":2,"comment_count":0} -{"id":"forgejo-mcp-broker-xot","title":"Phase 4b: bridge integration test against real forgejo-mcp","description":"Drive the bridge with initialize -\u003e tools/list -\u003e tools/call get_forgejo_mcp_server_version against a real forgejo-mcp subprocess. Validates the opaque-pipe assumption.","acceptance_criteria":"Full handshake, tools/list returns expected set, tools/call returns a version string. Tagged as integration test if runtime exceeds 2s.","status":"in_progress","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:16Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T14:10:04Z","started_at":"2026-04-27T14:10:04Z","dependencies":[{"issue_id":"forgejo-mcp-broker-xot","depends_on_id":"forgejo-mcp-broker-am1","type":"blocks","created_at":"2026-04-24T17:45:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":0,"comment_count":0} +{"id":"forgejo-mcp-broker-xot","title":"Phase 4b: bridge integration test against real forgejo-mcp","description":"Drive the bridge with initialize -\u003e tools/list -\u003e tools/call get_forgejo_mcp_server_version against a real forgejo-mcp subprocess. Validates the opaque-pipe assumption.","acceptance_criteria":"Full handshake, tools/list returns expected set, tools/call returns a version string. Tagged as integration test if runtime exceeds 2s.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:16Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T14:28:39Z","started_at":"2026-04-27T14:10:04Z","closed_at":"2026-04-27T14:28:39Z","close_reason":"Bridge integration test passes against real forgejo-mcp 2.2.0: MCP handshake (initialize → notifications/initialized → tools/list → tools/call) round-trips through bridge cleanly. Fake Forgejo covers /api/v1/version and /api/v1/user probes. Phase 4 complete.","dependencies":[{"issue_id":"forgejo-mcp-broker-xot","depends_on_id":"forgejo-mcp-broker-am1","type":"blocks","created_at":"2026-04-24T17:45:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":0,"comment_count":0} {"id":"forgejo-mcp-broker-31t","title":"Phase 3b: supervisor stress tests (FD/goroutine/zombie leak detection)","description":"1000 spawn/stop cycles under -race. Verify no FD leak, no goroutine leak (go.uber.org/goleak), no zombies (wait4 returns ECHILD when idle).","acceptance_criteria":"Cycle test passes under -race. FD count stable within a small constant. goleak detects no extra goroutines after test.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:15Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T14:04:42Z","started_at":"2026-04-27T12:00:32Z","closed_at":"2026-04-27T14:04:42Z","close_reason":"Stress tests in place: 1000-cycle spawn/reap and 200-cycle Stop both clean under -race; FD/goroutine/zombie deltas all single-digit. Driver: /bin/true and /bin/cat (helper-process recursion at scale exposed an unrelated Go pidfd interaction). Supervisor now defensively closes pipe handles post-Wait.","dependencies":[{"issue_id":"forgejo-mcp-broker-31t","depends_on_id":"forgejo-mcp-broker-zuq","type":"blocks","created_at":"2026-04-24T17:45:26Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":0,"comment_count":0} {"id":"forgejo-mcp-broker-am1","title":"Phase 4a: internal/bridge JSON-RPC pipe + SSE writer","description":"Given a supervisor.Child: inbound HTTP JSON -\u003e newline-framed stdin; stdout lines -\u003e SSE frames. Handle client disconnect without killing the child.","acceptance_criteria":"Unit tests with mock Child that echoes: request/response round trip, multiple concurrent requests with correct id routing, client disconnect mid-stream.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:15Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T11:59:35Z","started_at":"2026-04-27T11:56:15Z","closed_at":"2026-04-27T11:59:35Z","close_reason":"Bridge shipped: per-id routing, SSE responses for request/reply messages, 204 for notifications, structured 4xx/5xx for malformed input. Decoupled from supervisor (takes pipes directly) for clean testing via io.Pipe. 90.0% coverage.","dependencies":[{"issue_id":"forgejo-mcp-broker-am1","depends_on_id":"forgejo-mcp-broker-zuq","type":"blocks","created_at":"2026-04-24T17:45:27Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":2,"comment_count":0} {"id":"forgejo-mcp-broker-wgo","title":"Phase 2e: OAuth security review + attack-path tests","description":"Phase 2 exit gate. Review every handler for classic OAuth vulns (open redirect, code replay, mix-up, token leak in logs, host spoofing). Add at least one test per attack class. Update design.md §8 with findings.","acceptance_criteria":"Review checklist documented. Tests added for: PKCE mismatch, stale code, token absent from log attributes, bad redirect_uri, mismatched state, replay of used code.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:14Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:14Z","dependencies":[{"issue_id":"forgejo-mcp-broker-wgo","depends_on_id":"forgejo-mcp-broker-b2o","type":"blocks","created_at":"2026-04-24T17:45:26Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-wgo","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:25Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":0,"comment_count":0} {"id":"forgejo-mcp-broker-zuq","title":"Phase 3a: internal/supervisor managed stdio subprocess","description":"Child type: Start, Stop(ctx) with SIGTERM -\u003e grace -\u003e SIGKILL, Wait+reap goroutine (no zombies), stderr drainer with prefix. Protocol-agnostic.","acceptance_criteria":"Unit tests against an echo-loop helper: round trip, graceful stop, kill-after-grace, child-exits-on-own detection, stderr capture. Manual spawn of real forgejo-mcp --transport stdio works.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:14Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T11:41:07Z","started_at":"2026-04-27T11:32:54Z","closed_at":"2026-04-27T11:41:07Z","close_reason":"internal/supervisor shipped: Start/Stop/Done/ExitErr/Pid, SIGTERM-\u003egrace-\u003eSIGKILL escalation, mandatory wait-and-reap. Test uses TestMain helper-process pattern; coverage 89.6% on the testable surface.","dependency_count":0,"dependent_count":3,"comment_count":0} {"id":"forgejo-mcp-broker-b2o","title":"Phase 2d: OAuth discovery endpoints (/.well-known/*)","description":"GET /.well-known/oauth-protected-resource and /.well-known/oauth-authorization-server. Issuer URLs MUST derive from cfg.PublicURL, never inbound headers (host-header attack defense per design.md §8).","acceptance_criteria":"Responses validate against RFC 8414/9728 shapes. Issuer URL sourced from config only. supported_scopes matches cfg.ForgejoOAuthScopes.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:13Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:13Z","dependencies":[{"issue_id":"forgejo-mcp-broker-b2o","depends_on_id":"forgejo-mcp-broker-pur","type":"blocks","created_at":"2026-04-24T17:45:25Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-b9i","title":"Phase 2b: internal/forgejo OAuth client","description":"Broker-side OAuth client for upstream Forgejo: authorize URL builder, code-to-token exchange, refresh_token grant, userinfo fetch, revoke. Used by AS callback and refresh machinery. Stateless; caller owns persistence.","acceptance_criteria":"Unit tests with httptest.Server fake Forgejo cover each grant plus error paths (wrong code, expired refresh, revoked token). No state persisted in this package.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:12Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T11:31:27Z","started_at":"2026-04-27T11:29:17Z","closed_at":"2026-04-27T11:31:27Z","close_reason":"internal/forgejo shipped: AuthorizeURL, ExchangeCode, Refresh, FetchUserInfo. Structured *forgejo.Error for OAuth failures (errors.As-friendly). 95.1% coverage. Stateless — caller owns persistence. Revocation deferred since upstream Forgejo lacks the endpoint.","dependency_count":0,"dependent_count":1,"comment_count":0} -{"id":"forgejo-mcp-broker-pur","title":"Phase 2c: internal/oauth AS endpoints (register, authorize, callback, token, revoke)","description":"Five OAuth handlers per design.md §4.1. RFC 7591 DCR with ephemeral client IDs, authorize -\u003e Forgejo delegation, callback minting broker auth codes, token exchange with SHA-256 hashing at rest, revoke. PKCE S256 required.","acceptance_criteria":"End-to-end curl walkthrough from plan.md phase 2 passes. All tokens hashed at rest. Auth codes single-use, 10-min TTL. Rejects missing PKCE, non-S256, wrong verifier, expired codes/tokens. Handler coverage \u003e=80%.","status":"open","priority":1,"issue_type":"task","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:12Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:45:12Z","dependencies":[{"issue_id":"forgejo-mcp-broker-pur","depends_on_id":"forgejo-mcp-broker-b9i","type":"blocks","created_at":"2026-04-24T17:45:24Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-pur","depends_on_id":"forgejo-mcp-broker-cpb","type":"blocks","created_at":"2026-04-24T17:45:24Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":5,"comment_count":0} +{"id":"forgejo-mcp-broker-pur","title":"Phase 2c: internal/oauth AS endpoints (register, authorize, callback, token, revoke)","description":"Five OAuth handlers per design.md §4.1. RFC 7591 DCR with ephemeral client IDs, authorize -\u003e Forgejo delegation, callback minting broker auth codes, token exchange with SHA-256 hashing at rest, revoke. PKCE S256 required.","acceptance_criteria":"End-to-end curl walkthrough from plan.md phase 2 passes. All tokens hashed at rest. Auth codes single-use, 10-min TTL. Rejects missing PKCE, non-S256, wrong verifier, expired codes/tokens. Handler coverage \u003e=80%.","status":"in_progress","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:12Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T14:30:02Z","started_at":"2026-04-27T14:30:02Z","dependencies":[{"issue_id":"forgejo-mcp-broker-pur","depends_on_id":"forgejo-mcp-broker-b9i","type":"blocks","created_at":"2026-04-24T17:45:24Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-pur","depends_on_id":"forgejo-mcp-broker-cpb","type":"blocks","created_at":"2026-04-24T17:45:24Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":2,"dependent_count":5,"comment_count":0} {"id":"forgejo-mcp-broker-cpb","title":"Phase 2a: OAuth tables migration","description":"Add migrations/0002_oauth_tables.sql creating clients, auth_codes, access_tokens, refresh_tokens per design.md §4.2. Broker tokens stored as SHA-256 hashes; Forgejo tokens cleartext (subprocess spawning requires them). See plan.md phase 2.","acceptance_criteria":"Migration applies on a fresh DB and is idempotent on reopen. Schema matches design.md §4.2. Tests verify every table and key column exists.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T15:45:04Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-27T11:28:20Z","started_at":"2026-04-27T11:26:17Z","closed_at":"2026-04-27T11:28:20Z","close_reason":"0002_oauth_tables.sql shipped: clients/auth_codes/access_tokens/refresh_tokens with cascading FKs, indexes on hot-path columns, and an oauth_schema_version marker. PRAGMA-driven tests verify columns; FK cascade tested in both directions.","dependency_count":0,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-8ei","title":"Phase 1: internal/httpserver with /healthz and graceful shutdown","description":"Implement internal/httpserver: constructs a *http.Server bound to cfg.Listen, mounts GET /healthz (returns 200 with JSON build-info from the build-info package, including version, git revision, build date, and current store status), handles SIGTERM/SIGINT by initiating graceful shutdown with a configurable deadline (default 10s). Uses log/slog for structured JSON logs. Exposes a Run(ctx) error method that blocks until shutdown completes.","acceptance_criteria":"go test ./internal/httpserver passes; GET /healthz returns expected JSON; sending SIGTERM causes Run to return nil within 2 seconds after in-flight requests complete; slow in-flight request is allowed to finish within the deadline, then forcibly closed.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T14:46:20Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:26:43Z","started_at":"2026-04-24T15:24:09Z","closed_at":"2026-04-24T15:26:43Z","close_reason":"httpserver shipped: /healthz with store probe, graceful shutdown with force-close fallback, ExtraHandler extension point. 97.9% coverage. internal/log also implemented in the same commit (100% coverage).","dependencies":[{"issue_id":"forgejo-mcp-broker-8ei","depends_on_id":"forgejo-mcp-broker-n84","type":"blocks","created_at":"2026-04-24T16:46:19Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":1,"dependent_count":1,"comment_count":0} {"id":"forgejo-mcp-broker-t37","title":"Phase 1: wire cmd/broker/main.go and integration test","description":"Final phase 1 task: wire config → log → store → httpserver in cmd/broker/main.go. Parse config, init slog, open store, start httpserver, wait for shutdown signal, close store, exit. Add an integration test under cmd/broker/ that builds the binary, runs it with a valid env + temp store path, curls /healthz, sends SIGTERM, verifies clean exit within 2s. This is the acceptance gate for phase 1.","acceptance_criteria":"make build; make test (incl. integration) pass; running the binary with missing config fails with a clear error; running with valid config serves /healthz; SIGTERM shuts down cleanly within 2s; /healthz JSON includes version, git revision, build date, and store OK status.","status":"closed","priority":1,"issue_type":"task","assignee":"Ole-Morten Duesund","owner":"olemd@glemt.net","created_at":"2026-04-24T14:46:20Z","created_by":"Ole-Morten Duesund","updated_at":"2026-04-24T15:29:44Z","started_at":"2026-04-24T15:27:58Z","closed_at":"2026-04-24T15:29:44Z","close_reason":"Main wired, signal.NotifyContext triggers shutdown cascade, integration tests green. Phase 1 complete: binary starts with valid config, serves /healthz JSON, shuts down cleanly on SIGTERM within 2s.","dependencies":[{"issue_id":"forgejo-mcp-broker-t37","depends_on_id":"forgejo-mcp-broker-8ei","type":"blocks","created_at":"2026-04-24T16:48:29Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t37","depends_on_id":"forgejo-mcp-broker-9jh","type":"blocks","created_at":"2026-04-24T16:48:29Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t37","depends_on_id":"forgejo-mcp-broker-9nq","type":"blocks","created_at":"2026-04-24T16:48:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"},{"issue_id":"forgejo-mcp-broker-t37","depends_on_id":"forgejo-mcp-broker-n84","type":"blocks","created_at":"2026-04-24T16:48:28Z","created_by":"Ole-Morten Duesund","metadata":"{}"}],"dependency_count":4,"dependent_count":0,"comment_count":0} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 0000000..b77a6fa --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,810 @@ +// Package oauth implements the broker's OAuth 2.1 authorization server +// surface — what Claude.ai (and other MCP clients) talk to. User auth is +// delegated to upstream Forgejo via internal/forgejo. +// +// Endpoints (RFC numbers in parentheses): +// POST /oauth/register — RFC 7591 dynamic client registration +// GET /oauth/authorize — RFC 6749 / 7636 authorize with PKCE S256 +// GET /oauth/callback — Forgejo redirects back here after user consent +// POST /oauth/token — authorization_code and refresh_token grants +// POST /oauth/revoke — RFC 7009 token revocation +// +// PKCE: required, S256 only. Plain method is rejected (this is OAuth 2.1). +// +// Token storage: every broker-issued access/refresh token is stored as a +// hex-encoded SHA-256 hash. The plaintext leaves the broker exactly once — +// in the body of the /oauth/token response. +// +// Pending authorizations (between /authorize and /callback) live in memory +// with a short TTL. A broker restart drops them, forcing the user to +// re-authorize; that's an acceptable UX hit for not adding another table. +package oauth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "database/sql" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo" + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/store" +) + +// AuthCodeTTL bounds how long a broker-issued authorization code stays +// usable. RFC 6749 §10.5 recommends "very short" — 10 minutes matches what +// most ASes use. +const AuthCodeTTL = 10 * time.Minute + +// PendingAuthTTL caps how long an in-flight /authorize → /callback flow can +// take. Anything longer is almost certainly an abandoned attempt. +const PendingAuthTTL = 10 * time.Minute + +// AccessTokenTTL is the lifetime of broker access tokens. Refresh tokens +// last considerably longer (see RefreshTokenTTL). +const AccessTokenTTL = 1 * time.Hour + +// RefreshTokenTTL is the lifetime of broker refresh tokens. 30 days lets a +// daily-active user stay logged in indefinitely while bounded enough that +// theft via stale backup is time-limited. +const RefreshTokenTTL = 30 * 24 * time.Hour + +// Server is the OAuth authorization server. Construct one with NewServer +// and mount its Handler under the broker's HTTP mux. +type Server struct { + store *store.Store + forgejo *forgejo.Client + issuer string // public URL, e.g. https://mcp.example.com — never derived from headers + scopes string // space-separated scope set requested from upstream Forgejo + now func() time.Time + log *slog.Logger + mu sync.Mutex // guards pendingAuths cleanup; map ops use sync.Map natively + pending sync.Map // forgejoState string → *pendingAuth + stopCh chan struct{} +} + +// Config bundles Server dependencies. +type Config struct { + Store *store.Store + Forgejo *forgejo.Client + Issuer string // required; e.g. https://mcp.example.com + Scopes string // optional; space-separated; defaults to "" + Now func() time.Time + Log *slog.Logger +} + +// NewServer validates the config and starts the periodic pending-auth +// reaper. Stop the reaper with Server.Close. +func NewServer(cfg Config) (*Server, error) { + if cfg.Store == nil { + return nil, errors.New("oauth: Store is required") + } + if cfg.Forgejo == nil { + return nil, errors.New("oauth: Forgejo client is required") + } + if cfg.Issuer == "" { + return nil, errors.New("oauth: Issuer is required") + } + now := cfg.Now + if now == nil { + now = time.Now + } + logger := cfg.Log + if logger == nil { + logger = slog.New(slog.DiscardHandler) + } + s := &Server{ + store: cfg.Store, + forgejo: cfg.Forgejo, + issuer: strings.TrimRight(cfg.Issuer, "/"), + scopes: cfg.Scopes, + now: now, + log: logger, + stopCh: make(chan struct{}), + } + go s.reapPendingLoop() + return s, nil +} + +// Close stops the pending-auth reaper. Safe to call multiple times. +func (s *Server) Close() { + select { + case <-s.stopCh: + // already closed + default: + close(s.stopCh) + } +} + +// Handler returns the http.Handler exposing all five OAuth endpoints. +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("POST /oauth/register", s.handleRegister) + mux.HandleFunc("GET /oauth/authorize", s.handleAuthorize) + mux.HandleFunc("GET /oauth/callback", s.handleCallback) + mux.HandleFunc("POST /oauth/token", s.handleToken) + mux.HandleFunc("POST /oauth/revoke", s.handleRevoke) + return mux +} + +// pendingAuth is the in-memory record of a /authorize → /callback flow. +type pendingAuth struct { + clientID string + redirectURI string + codeChallenge string + codeChallengeMethod string + scopes string + clientState string + expiresAt time.Time +} + +// ============================================================================ +// Helpers — token generation, hashing, OAuth error responses +// ============================================================================ + +// secureToken returns a hex-encoded cryptographically-random string of the +// given byte length. 32 bytes ⇒ 256 bits of entropy ⇒ 64 hex chars. +func secureToken(nBytes int) string { + b := make([]byte, nBytes) + if _, err := rand.Read(b); err != nil { + // crypto/rand failing is a system-level emergency. Panicking here is + // the right move — operating without entropy is worse than crashing. + panic("oauth: crypto/rand failed: " + err.Error()) + } + return hex.EncodeToString(b) +} + +// hashToken returns the hex-encoded SHA-256 of the given token. Used at the +// store boundary so plaintext tokens never persist. +func hashToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} + +// verifyPKCE returns true iff base64url(sha256(verifier)) equals challenge. +// Constant-time comparison prevents timing leaks on the verifier. +func verifyPKCE(verifier, challenge string) bool { + sum := sha256.Sum256([]byte(verifier)) + got := base64.RawURLEncoding.EncodeToString(sum[:]) + return subtle.ConstantTimeCompare([]byte(got), []byte(challenge)) == 1 +} + +// writeJSON writes obj as JSON with the given status. Errors are logged but +// not surfaced — by the time encoding could fail, headers are already out. +func writeJSON(w http.ResponseWriter, status int, obj any) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(obj) +} + +// writeOAuthError renders an RFC 6749 §5.2 error response. +func writeOAuthError(w http.ResponseWriter, status int, code, description string) { + writeJSON(w, status, map[string]string{ + "error": code, + "error_description": description, + }) +} + +// ============================================================================ +// Pending-auth reaper +// ============================================================================ + +func (s *Server) reapPendingLoop() { + t := time.NewTicker(time.Minute) + defer t.Stop() + for { + select { + case <-s.stopCh: + return + case <-t.C: + s.reapPending() + } + } +} + +func (s *Server) reapPending() { + now := s.now() + s.pending.Range(func(k, v any) bool { + if pa, ok := v.(*pendingAuth); ok && now.After(pa.expiresAt) { + s.pending.Delete(k) + } + return true + }) +} + +// ============================================================================ +// /oauth/register — RFC 7591 dynamic client registration +// ============================================================================ + +type registerRequest struct { + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type registerResponse struct { + ClientID string `json:"client_id"` + ClientIDIssuedAt int64 `json:"client_id_issued_at"` + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` +} + +func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { + var req registerRequest + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(&req); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_client_metadata", + "could not parse JSON: "+err.Error()) + return + } + if len(req.RedirectURIs) == 0 { + writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri", + "at least one redirect_uri is required") + return + } + for _, ru := range req.RedirectURIs { + if err := validateRedirectURI(ru); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_redirect_uri", err.Error()) + return + } + } + + clientID := secureToken(16) + now := s.now().Unix() + + uris, _ := json.Marshal(req.RedirectURIs) + meta, _ := json.Marshal(req) + + if _, err := s.store.DB().ExecContext(r.Context(), + `INSERT INTO clients (client_id, redirect_uris, metadata_json, created_at) + VALUES (?, ?, ?, ?)`, + clientID, string(uris), string(meta), now, + ); err != nil { + s.log.Error("register: insert client", slog.String("err", err.Error())) + writeOAuthError(w, http.StatusInternalServerError, "server_error", "client registration failed") + return + } + + writeJSON(w, http.StatusCreated, registerResponse{ + ClientID: clientID, + ClientIDIssuedAt: now, + RedirectURIs: req.RedirectURIs, + ClientName: req.ClientName, + TokenEndpointAuthMethod: "none", // PKCE-only public clients + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + }) +} + +func validateRedirectURI(raw string) error { + u, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("redirect_uri %q: %w", raw, err) + } + if u.Scheme == "" { + return fmt.Errorf("redirect_uri %q: missing scheme", raw) + } + // RFC 6749 §3.1.2.1 requires absolute URIs; we additionally require + // http/https or claude.ai's documented custom scheme. Accept anything + // non-empty for now; tighten later if needed. + return nil +} + +// ============================================================================ +// /oauth/authorize — RFC 6749 + RFC 7636 PKCE +// ============================================================================ + +func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + clientID := q.Get("client_id") + redirectURI := q.Get("redirect_uri") + responseType := q.Get("response_type") + clientState := q.Get("state") + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + scope := q.Get("scope") + + // Errors that come BEFORE we know a valid redirect_uri: respond + // directly. RFC 6749 §3.1.2.4 — if redirect_uri is invalid, do not + // redirect; render an error instead. + + if clientID == "" { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", "client_id is required") + return + } + registered, err := s.lookupClientRedirectURIs(r.Context(), clientID) + if err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_client", "unknown client_id") + return + } + if redirectURI == "" { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", "redirect_uri is required") + return + } + if !redirectURIMatches(registered, redirectURI) { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", + "redirect_uri does not match any registered URI") + return + } + + // From here on, errors can be returned via redirect to the client + // (per RFC 6749 §4.1.2.1). We use that path to surface PKCE/scope + // problems back to the calling app. + + if responseType != "code" { + redirectAuthError(w, r, redirectURI, clientState, "unsupported_response_type", + "only response_type=code is supported") + return + } + if codeChallenge == "" { + redirectAuthError(w, r, redirectURI, clientState, "invalid_request", + "PKCE is required: code_challenge missing") + return + } + if codeChallengeMethod != "S256" { + redirectAuthError(w, r, redirectURI, clientState, "invalid_request", + "only code_challenge_method=S256 is supported") + return + } + + // Stash the in-flight authorization. forgejoState is the value we'll + // pass to Forgejo and read back from /callback. + forgejoState := secureToken(24) + s.pending.Store(forgejoState, &pendingAuth{ + clientID: clientID, + redirectURI: redirectURI, + codeChallenge: codeChallenge, + codeChallengeMethod: codeChallengeMethod, + scopes: scope, + clientState: clientState, + expiresAt: s.now().Add(PendingAuthTTL), + }) + + // Redirect the user-agent to Forgejo. Forgejo asks the user to consent; + // on success it'll redirect back to our /oauth/callback with code+state. + upstream := s.forgejo.AuthorizeURL(forgejo.AuthorizeURLOptions{ + RedirectURI: s.issuer + "/oauth/callback", + State: forgejoState, + Scopes: s.scopes, + CodeChallenge: "", // we don't pass our PKCE through; the broker is + CodeChallengeMethod: "", // a confidential OAuth client of Forgejo. + }) + http.Redirect(w, r, upstream, http.StatusFound) +} + +func (s *Server) lookupClientRedirectURIs(ctx context.Context, clientID string) ([]string, error) { + var raw string + row := s.store.DB().QueryRowContext(ctx, + `SELECT redirect_uris FROM clients WHERE client_id = ?`, clientID) + if err := row.Scan(&raw); err != nil { + return nil, err + } + var uris []string + if err := json.Unmarshal([]byte(raw), &uris); err != nil { + return nil, err + } + return uris, nil +} + +func redirectURIMatches(registered []string, candidate string) bool { + for _, r := range registered { + if r == candidate { + return true + } + } + return false +} + +// redirectAuthError sends the user-agent back to the client's redirect_uri +// with error=... and state=... in the query string, per RFC 6749 §4.1.2.1. +func redirectAuthError(w http.ResponseWriter, r *http.Request, redirectURI, state, code, description string) { + u, err := url.Parse(redirectURI) + if err != nil { + // Should be unreachable since redirectURIMatches already verified + // it parses — but be safe. + writeOAuthError(w, http.StatusBadRequest, code, description) + return + } + q := u.Query() + q.Set("error", code) + if description != "" { + q.Set("error_description", description) + } + if state != "" { + q.Set("state", state) + } + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) +} + +// ============================================================================ +// /oauth/callback — Forgejo redirects here after user consent +// ============================================================================ + +func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + state := q.Get("state") + upstreamCode := q.Get("code") + upstreamErr := q.Get("error") + + v, ok := s.pending.LoadAndDelete(state) + if !ok { + http.Error(w, "unknown or expired state; please re-authorize", http.StatusBadRequest) + return + } + pa := v.(*pendingAuth) + if s.now().After(pa.expiresAt) { + http.Error(w, "authorization expired; please re-authorize", http.StatusBadRequest) + return + } + + if upstreamErr != "" { + desc := q.Get("error_description") + redirectAuthError(w, r, pa.redirectURI, pa.clientState, upstreamErr, desc) + return + } + if upstreamCode == "" { + redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", + "upstream returned no code") + return + } + + tok, err := s.forgejo.ExchangeCode(r.Context(), upstreamCode, "" /* no PKCE w/ Forgejo */, s.issuer+"/oauth/callback") + if err != nil { + s.log.Error("callback: forgejo exchange", slog.String("err", err.Error())) + redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", + "upstream code exchange failed") + return + } + + ui, err := s.forgejo.FetchUserInfo(r.Context(), tok.AccessToken) + if err != nil { + s.log.Error("callback: fetch userinfo", slog.String("err", err.Error())) + redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", + "upstream userinfo failed") + return + } + + brokerCode := secureToken(32) + now := s.now().Unix() + + if _, err := s.store.DB().ExecContext(r.Context(), + `INSERT INTO auth_codes + (code, client_id, redirect_uri, code_challenge, code_challenge_method, + scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, + forgejo_user_id, forgejo_username, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + brokerCode, pa.clientID, pa.redirectURI, pa.codeChallenge, pa.codeChallengeMethod, + pa.scopes, tok.AccessToken, tok.RefreshToken, + s.now().Add(time.Duration(tok.ExpiresIn)*time.Second).Unix(), + userIDInt64(ui.Sub), ui.PreferredUsername, + s.now().Add(AuthCodeTTL).Unix(), + ); err != nil { + s.log.Error("callback: insert auth_code", slog.String("err", err.Error())) + redirectAuthError(w, r, pa.redirectURI, pa.clientState, "server_error", + "failed to persist authorization code") + return + } + _ = now + + // Redirect back to the client with our code. + u, _ := url.Parse(pa.redirectURI) + rq := u.Query() + rq.Set("code", brokerCode) + if pa.clientState != "" { + rq.Set("state", pa.clientState) + } + u.RawQuery = rq.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) +} + +// userIDInt64 best-effort converts a string user-id (Forgejo OIDC `sub`) to +// an int64. Returns 0 if it can't parse — the username column is the +// reliable identity carrier; user_id is for log correlation. +func userIDInt64(sub string) int64 { + var n int64 + for _, c := range sub { + if c < '0' || c > '9' { + return 0 + } + n = n*10 + int64(c-'0') + } + return n +} + +// ============================================================================ +// /oauth/token — authorization_code and refresh_token grants +// ============================================================================ + +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form") + return + } + switch r.PostForm.Get("grant_type") { + case "authorization_code": + s.tokenAuthCodeGrant(w, r) + case "refresh_token": + s.tokenRefreshGrant(w, r) + default: + writeOAuthError(w, http.StatusBadRequest, "unsupported_grant_type", + "only authorization_code and refresh_token are supported") + } +} + +func (s *Server) tokenAuthCodeGrant(w http.ResponseWriter, r *http.Request) { + code := r.PostForm.Get("code") + clientID := r.PostForm.Get("client_id") + redirectURI := r.PostForm.Get("redirect_uri") + codeVerifier := r.PostForm.Get("code_verifier") + + if code == "" || clientID == "" || redirectURI == "" || codeVerifier == "" { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", + "code, client_id, redirect_uri, and code_verifier are required") + return + } + + // Look up the auth code and lock it via UPDATE ... used_at single-shot. + row := s.store.DB().QueryRowContext(r.Context(), + `SELECT client_id, redirect_uri, code_challenge, code_challenge_method, + scopes, forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, + forgejo_user_id, forgejo_username, expires_at, used_at + FROM auth_codes WHERE code = ?`, code) + var ( + storedClientID, storedRedirectURI, storedChallenge, storedMethod, storedScopes string + fjAccess, fjRefresh, fjUsername string + fjUserID int64 + fjExpiresAt, expiresAt int64 + usedAt sql.NullInt64 + ) + if err := row.Scan(&storedClientID, &storedRedirectURI, &storedChallenge, &storedMethod, + &storedScopes, &fjAccess, &fjRefresh, &fjExpiresAt, &fjUserID, &fjUsername, + &expiresAt, &usedAt); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code not found") + return + } + if usedAt.Valid { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used") + return + } + if s.now().Unix() > expiresAt { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code expired") + return + } + if storedClientID != clientID { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch") + return + } + if storedRedirectURI != redirectURI { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + return + } + if storedMethod != "S256" || !verifyPKCE(codeVerifier, storedChallenge) { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verification failed") + return + } + + // Atomically mark the code used. The WHERE clause re-checks used_at IS + // NULL so two concurrent /token requests with the same code can't both + // succeed (only one UPDATE will affect a row). + res, err := s.store.DB().ExecContext(r.Context(), + `UPDATE auth_codes SET used_at = ? WHERE code = ? AND used_at IS NULL`, + s.now().Unix(), code) + if err != nil { + s.log.Error("token: mark code used", slog.String("err", err.Error())) + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if n, _ := res.RowsAffected(); n != 1 { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code already used") + return + } + + // Mint broker access + refresh tokens. + accessToken := secureToken(32) + refreshToken := secureToken(32) + now := s.now().Unix() + + tx, err := s.store.DB().BeginTx(r.Context(), nil) + if err != nil { + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if _, err := tx.ExecContext(r.Context(), + `INSERT INTO access_tokens + (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, + forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, + expires_at, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + hashToken(accessToken), clientID, fjUserID, fjUsername, storedScopes, + fjAccess, fjRefresh, fjExpiresAt, + s.now().Add(AccessTokenTTL).Unix(), now); err != nil { + _ = tx.Rollback() + s.log.Error("token: insert access_token", slog.String("err", err.Error())) + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if _, err := tx.ExecContext(r.Context(), + `INSERT INTO refresh_tokens + (token_hash, access_token_hash, client_id, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + hashToken(refreshToken), hashToken(accessToken), clientID, + s.now().Add(RefreshTokenTTL).Unix(), now); err != nil { + _ = tx.Rollback() + s.log.Error("token: insert refresh_token", slog.String("err", err.Error())) + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if err := tx.Commit(); err != nil { + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + + writeJSON(w, http.StatusOK, tokenResponse{ + AccessToken: accessToken, + TokenType: "Bearer", + ExpiresIn: int(AccessTokenTTL.Seconds()), + RefreshToken: refreshToken, + Scope: storedScopes, + }) +} + +func (s *Server) tokenRefreshGrant(w http.ResponseWriter, r *http.Request) { + refreshToken := r.PostForm.Get("refresh_token") + clientID := r.PostForm.Get("client_id") + if refreshToken == "" || clientID == "" { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", + "refresh_token and client_id are required") + return + } + + rtHash := hashToken(refreshToken) + row := s.store.DB().QueryRowContext(r.Context(), + `SELECT rt.access_token_hash, rt.client_id, rt.expires_at, rt.revoked_at, + at.forgejo_user_id, at.forgejo_username, at.scopes, + at.forgejo_access_token, at.forgejo_refresh_token, at.forgejo_token_expires_at + FROM refresh_tokens rt + JOIN access_tokens at ON at.token_hash = rt.access_token_hash + WHERE rt.token_hash = ?`, rtHash) + var ( + oldAccessHash, storedClientID, fjUsername, scopes string + fjAccess, fjRefresh string + fjUserID int64 + expiresAt, fjExpiresAt int64 + revokedAt sql.NullInt64 + ) + if err := row.Scan(&oldAccessHash, &storedClientID, &expiresAt, &revokedAt, + &fjUserID, &fjUsername, &scopes, &fjAccess, &fjRefresh, &fjExpiresAt); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token not found") + return + } + if revokedAt.Valid { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token revoked") + return + } + if s.now().Unix() > expiresAt { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "refresh token expired") + return + } + if storedClientID != clientID { + writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "client_id mismatch") + return + } + + // Mint a new access token. Refresh-token rotation: also issue a new + // refresh token and revoke the old one. + newAccess := secureToken(32) + newRefresh := secureToken(32) + now := s.now().Unix() + + tx, err := s.store.DB().BeginTx(r.Context(), nil) + if err != nil { + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if _, err := tx.ExecContext(r.Context(), + `INSERT INTO access_tokens + (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, + forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at, + expires_at, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + hashToken(newAccess), clientID, fjUserID, fjUsername, scopes, + fjAccess, fjRefresh, fjExpiresAt, + s.now().Add(AccessTokenTTL).Unix(), now); err != nil { + _ = tx.Rollback() + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if _, err := tx.ExecContext(r.Context(), + `INSERT INTO refresh_tokens + (token_hash, access_token_hash, client_id, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + hashToken(newRefresh), hashToken(newAccess), clientID, + s.now().Add(RefreshTokenTTL).Unix(), now); err != nil { + _ = tx.Rollback() + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + // Revoke the old refresh token (rotation per RFC 6749 §10.4). + if _, err := tx.ExecContext(r.Context(), + `UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ?`, + now, rtHash); err != nil { + _ = tx.Rollback() + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + if err := tx.Commit(); err != nil { + writeOAuthError(w, http.StatusInternalServerError, "server_error", "") + return + } + + writeJSON(w, http.StatusOK, tokenResponse{ + AccessToken: newAccess, + TokenType: "Bearer", + ExpiresIn: int(AccessTokenTTL.Seconds()), + RefreshToken: newRefresh, + Scope: scopes, + }) +} + +// ============================================================================ +// /oauth/revoke — RFC 7009 +// ============================================================================ + +func (s *Server) handleRevoke(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", "could not parse form") + return + } + tokenStr := r.PostForm.Get("token") + hint := r.PostForm.Get("token_type_hint") // "access_token" or "refresh_token" + if tokenStr == "" { + writeOAuthError(w, http.StatusBadRequest, "invalid_request", "token is required") + return + } + + hash := hashToken(tokenStr) + now := s.now().Unix() + + // Try the hinted table first; fall back to the other. RFC 7009 says + // invalid tokens still get a 200 — clients shouldn't probe. + if hint != "refresh_token" { + _, _ = s.store.DB().ExecContext(r.Context(), + `UPDATE access_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`, + now, hash) + } + if hint != "access_token" { + _, _ = s.store.DB().ExecContext(r.Context(), + `UPDATE refresh_tokens SET revoked_at = ? WHERE token_hash = ? AND revoked_at IS NULL`, + now, hash) + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/oauth/oauth_internal_test.go b/internal/oauth/oauth_internal_test.go new file mode 100644 index 0000000..2bd3055 --- /dev/null +++ b/internal/oauth/oauth_internal_test.go @@ -0,0 +1,107 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "strings" + "testing" + "time" +) + +// reapPending only runs every minute via the background ticker, so +// production tests can't drive it within a reasonable timeout. This +// internal test exercises the reaper logic directly. +func TestReapPending_RemovesExpired(t *testing.T) { + now := time.Date(2026, 4, 27, 12, 0, 0, 0, time.UTC) + s := &Server{now: func() time.Time { return now }} + + s.pending.Store("expired-1", &pendingAuth{expiresAt: now.Add(-time.Minute)}) + s.pending.Store("expired-2", &pendingAuth{expiresAt: now.Add(-time.Hour)}) + s.pending.Store("fresh", &pendingAuth{expiresAt: now.Add(time.Hour)}) + + s.reapPending() + + for _, key := range []string{"expired-1", "expired-2"} { + if _, ok := s.pending.Load(key); ok { + t.Errorf("%q should have been reaped", key) + } + } + if _, ok := s.pending.Load("fresh"); !ok { + t.Error("fresh entry was wrongly reaped") + } +} + +// secureToken's output is non-deterministic, so we test shape: hex, +// expected length, and that two consecutive calls differ. +func TestSecureToken_Shape(t *testing.T) { + a := secureToken(16) + b := secureToken(16) + if a == b { + t.Error("two secureToken calls produced identical output") + } + if len(a) != 32 { + t.Errorf("secureToken(16) length = %d, want 32", len(a)) + } + for _, c := range a { + if !strings.ContainsRune("0123456789abcdef", c) { + t.Errorf("non-hex character %q in token", c) + break + } + } +} + +func TestVerifyPKCE_Roundtrip(t *testing.T) { + verifier := "the-quick-brown-fox-jumps-over-the-lazy-dog-12345678" + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + if !verifyPKCE(verifier, challenge) { + t.Error("verifyPKCE should accept matching verifier+challenge") + } + if verifyPKCE(verifier, challenge+"x") { + t.Error("verifyPKCE should reject a mutated challenge") + } + if verifyPKCE("WRONG", challenge) { + t.Error("verifyPKCE should reject a wrong verifier") + } +} + +func TestUserIDInt64(t *testing.T) { + cases := map[string]int64{ + "42": 42, + "": 0, + "abc": 0, + "123abc": 0, + "99999999999": 99999999999, + } + for in, want := range cases { + if got := userIDInt64(in); got != want { + t.Errorf("userIDInt64(%q) = %d, want %d", in, got, want) + } + } +} + +func TestValidateRedirectURI(t *testing.T) { + cases := []struct { + name string + in string + ok bool + }{ + {"https", "https://app.example.com/cb", true}, + {"http_loopback", "http://localhost:1234/cb", true}, + {"custom_scheme", "claude://oauth/cb", true}, + {"missing_scheme", "app.example.com/cb", false}, + {"unparseable", "://no-scheme", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateRedirectURI(tc.in) + if tc.ok && err != nil { + t.Errorf("expected ok, got %v", err) + } + if !tc.ok && err == nil { + t.Error("expected error, got nil") + } + }) + } +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go new file mode 100644 index 0000000..a7c07e0 --- /dev/null +++ b/internal/oauth/oauth_test.go @@ -0,0 +1,1038 @@ +package oauth_test + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo" + brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log" + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth" + "kode.naiv.no/olemd/forgejo-mcp-broker/internal/store" +) + +// fixture bundles the harness each test wires up: a real store, a fake +// Forgejo instance, an OAuth server, and an httptest.Server fronting it. +type fixture struct { + t *testing.T + store *store.Store + server *oauth.Server + httpServer *httptest.Server + fakeForgejo *fakeForgejo + now func() time.Time + clock *atomic.Int64 // unix seconds +} + +func newFixture(t *testing.T) *fixture { + t.Helper() + + st, err := store.Open(t.Context(), filepath.Join(t.TempDir(), "broker.db")) + if err != nil { + t.Fatalf("store: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + fake := newFakeForgejo(t) + + fjClient, err := forgejo.NewClient(forgejo.ClientConfig{ + BaseURL: fake.server.URL, ClientID: "broker-app", ClientSecret: "broker-secret", + }) + if err != nil { + t.Fatalf("forgejo client: %v", err) + } + + clock := &atomic.Int64{} + clock.Store(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC).Unix()) + now := func() time.Time { return time.Unix(clock.Load(), 0).UTC() } + + srv, err := oauth.NewServer(oauth.Config{ + Store: st, Forgejo: fjClient, + Issuer: "https://broker.example.com", + Scopes: "read:user write:repository", + Now: now, + Log: brokerlog.Discard(), + }) + if err != nil { + t.Fatalf("oauth.NewServer: %v", err) + } + t.Cleanup(srv.Close) + + httpSrv := httptest.NewServer(srv.Handler()) + t.Cleanup(httpSrv.Close) + + return &fixture{ + t: t, + store: st, + server: srv, + httpServer: httpSrv, + fakeForgejo: fake, + now: now, + clock: clock, + } +} + +// advance moves the test clock forward by d. +func (f *fixture) advance(d time.Duration) { + f.clock.Add(int64(d.Seconds())) +} + +// registerClient calls /oauth/register with the given redirect URIs and +// returns the issued client_id. +func (f *fixture) registerClient(redirectURIs ...string) string { + f.t.Helper() + body, _ := json.Marshal(map[string]any{"redirect_uris": redirectURIs, "client_name": "Test"}) + resp, err := http.Post(f.httpServer.URL+"/oauth/register", + "application/json", strings.NewReader(string(body))) + if err != nil { + f.t.Fatalf("register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + b, _ := io.ReadAll(resp.Body) + f.t.Fatalf("register status %d: %s", resp.StatusCode, b) + } + var rr struct { + ClientID string `json:"client_id"` + } + if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil { + f.t.Fatalf("decode register: %v", err) + } + return rr.ClientID +} + +// pkce returns (verifier, challenge) for a fresh PKCE pair. +func pkce(t *testing.T, verifier string) (string, string) { + t.Helper() + sum := sha256.Sum256([]byte(verifier)) + return verifier, base64.RawURLEncoding.EncodeToString(sum[:]) +} + +// noRedirectClient is an http.Client that does not follow redirects, so +// tests can inspect the 302 responses /authorize and /callback emit. +var noRedirectClient = &http.Client{ + CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }, +} + +// fakeForgejo is the same minimal Forgejo stub used in the bridge +// integration test, plus the OAuth token + userinfo endpoints needed here. +type fakeForgejo struct { + t *testing.T + server *httptest.Server + + tokenStatus int + tokenAccessToken string + tokenRefresh string + tokenError string + + userSub string + userUsername string +} + +func newFakeForgejo(t *testing.T) *fakeForgejo { + t.Helper() + f := &fakeForgejo{ + t: t, + tokenStatus: http.StatusOK, + tokenAccessToken: "fj-access", + tokenRefresh: "fj-refresh", + userSub: "42", + userUsername: "alice", + } + mux := http.NewServeMux() + mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(f.tokenStatus) + if f.tokenError != "" { + _, _ = io.WriteString(w, fmt.Sprintf(`{"error":%q}`, f.tokenError)) + return + } + _, _ = io.WriteString(w, fmt.Sprintf( + `{"access_token":%q,"refresh_token":%q,"token_type":"bearer","expires_in":3600}`, + f.tokenAccessToken, f.tokenRefresh)) + }) + mux.HandleFunc("/login/oauth/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, fmt.Sprintf( + `{"sub":%q,"preferred_username":%q,"name":"Alice"}`, f.userSub, f.userUsername)) + }) + f.server = httptest.NewServer(mux) + t.Cleanup(f.server.Close) + return f +} + +// -------------------------------------------------------------------------- +// /oauth/register +// -------------------------------------------------------------------------- + +func TestRegister_HappyPath(t *testing.T) { + fx := newFixture(t) + body := `{"redirect_uris":["https://app.example.com/cb"],"client_name":"Demo"}` + resp, err := http.Post(fx.httpServer.URL+"/oauth/register", + "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("post: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Errorf("status = %d, want 201", resp.StatusCode) + } + var rr map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil { + t.Fatalf("decode: %v", err) + } + if rr["client_id"] == "" { + t.Error("client_id empty") + } + if rr["token_endpoint_auth_method"] != "none" { + t.Errorf("auth_method = %v, want none", rr["token_endpoint_auth_method"]) + } +} + +func TestRegister_EmptyRedirectURIs(t *testing.T) { + fx := newFixture(t) + resp, _err := http.Post(fx.httpServer.URL+"/oauth/register", + "application/json", strings.NewReader(`{"redirect_uris":[]}`)) + if _err != nil { + t.Fatalf("http: %v", _err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestRegister_BadJSON(t *testing.T) { + fx := newFixture(t) + resp, _err := http.Post(fx.httpServer.URL+"/oauth/register", + "application/json", strings.NewReader(`{not json`)) + if _err != nil { + t.Fatalf("http: %v", _err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// -------------------------------------------------------------------------- +// /oauth/authorize +// -------------------------------------------------------------------------- + +func TestAuthorize_RedirectsToForgejo(t *testing.T) { + fx := newFixture(t) + clientID := fx.registerClient("https://app.example.com/cb") + + _, challenge := pkce(t, "verifier-123") + q := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"client-csrf"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + "scope": {"read:user"}, + } + resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("get: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusFound { + t.Fatalf("status = %d, want 302", resp.StatusCode) + } + loc := resp.Header.Get("Location") + if !strings.Contains(loc, fx.fakeForgejo.server.URL+"/login/oauth/authorize") { + t.Errorf("not redirected to Forgejo: %s", loc) + } +} + +func TestAuthorize_UnknownClient(t *testing.T) { + fx := newFixture(t) + q := url.Values{ + "response_type": {"code"}, + "client_id": {"nope"}, + "redirect_uri": {"https://app.example.com/cb"}, + } + resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestAuthorize_MismatchedRedirectURI(t *testing.T) { + fx := newFixture(t) + clientID := fx.registerClient("https://app.example.com/cb") + q := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"https://evil.example.com/cb"}, + } + resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestAuthorize_MissingPKCE_RedirectsWithError(t *testing.T) { + fx := newFixture(t) + clientID := fx.registerClient("https://app.example.com/cb") + q := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + if resp.StatusCode != http.StatusFound { + t.Fatalf("status = %d, want 302", resp.StatusCode) + } + loc := resp.Header.Get("Location") + if !strings.Contains(loc, "error=invalid_request") { + t.Errorf("redirect missing error param: %s", loc) + } + if !strings.Contains(loc, "state=st") { + t.Errorf("redirect missing state echo: %s", loc) + } +} + +func TestAuthorize_NonS256_RedirectsWithError(t *testing.T) { + fx := newFixture(t) + clientID := fx.registerClient("https://app.example.com/cb") + q := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_challenge": {"abc"}, + "code_challenge_method": {"plain"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + if resp.StatusCode != http.StatusFound { + t.Fatalf("status = %d, want 302", resp.StatusCode) + } + if !strings.Contains(resp.Header.Get("Location"), "error=invalid_request") { + t.Errorf("missing error: %s", resp.Header.Get("Location")) + } +} + +// -------------------------------------------------------------------------- +// /oauth/callback (also exercises the rest of the happy-path flow) +// -------------------------------------------------------------------------- + +// runFullFlow walks an entire authorize → callback → token sequence and +// returns the final token response. Used by happy-path tests and by +// downstream tests (refresh, revoke) that need a valid token to play with. +func runFullFlow(t *testing.T, fx *fixture, redirectURI, clientID, verifier string) tokenBundle { + t.Helper() + _, challenge := pkce(t, verifier) + clientState := "client-csrf-" + verifier + + // 1. /authorize: capture the redirect to Forgejo and extract its state. + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {redirectURI}, + "state": {clientState}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + "scope": {"read:user"}, + } + resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()) + if err != nil { + t.Fatalf("authorize: %v", err) + } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + forgejoState := loc.Query().Get("state") + if forgejoState == "" { + t.Fatalf("authorize did not produce a forgejo state: %s", loc) + } + + // 2. /callback: simulate Forgejo redirecting back with code + state. + cbQ := url.Values{"code": {"upstream-code"}, "state": {forgejoState}} + resp, err = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + cbQ.Encode()) + if err != nil { + t.Fatalf("callback: %v", err) + } + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + if brokerCode == "" { + t.Fatalf("callback did not return broker code: %s", cbLoc) + } + if cbLoc.Query().Get("state") != clientState { + t.Errorf("callback dropped client_state: got %q want %q", + cbLoc.Query().Get("state"), clientState) + } + + // 3. /token: exchange broker code for access+refresh tokens. + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {clientID}, + "redirect_uri": {redirectURI}, + "code_verifier": {verifier}, + } + resp, err = http.PostForm(fx.httpServer.URL+"/oauth/token", form) + if err != nil { + t.Fatalf("token: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("token status = %d: %s", resp.StatusCode, b) + } + var tr tokenBundle + if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { + t.Fatalf("decode token: %v", err) + } + return tr +} + +type tokenBundle struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +func TestEndToEnd_RegisterAuthorizeCallbackToken(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-flow-1") + + if tok.AccessToken == "" || tok.RefreshToken == "" { + t.Errorf("missing tokens: %+v", tok) + } + if tok.TokenType != "Bearer" { + t.Errorf("TokenType = %q, want Bearer", tok.TokenType) + } +} + +func TestCallback_UnknownState(t *testing.T) { + fx := newFixture(t) + q := url.Values{"code": {"x"}, "state": {"unknown"}} + resp, _err := http.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// -------------------------------------------------------------------------- +// /oauth/token — additional grants and edge cases +// -------------------------------------------------------------------------- + +func TestToken_AuthCode_PKCEFails(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + + // Authorize with one verifier, then try to exchange with a different one. + _, challenge := pkce(t, "good-verifier") + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + forgejoState := loc.Query().Get("state") + + resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_verifier": {"WRONG-verifier"}, + } + resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) + if err != nil { + t.Fatalf("post: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "PKCE") && !strings.Contains(string(body), "invalid_grant") { + t.Errorf("expected PKCE error: %s", body) + } +} + +func TestToken_AuthCode_Reuse(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-reuse") + + // Same code can't be exchanged a second time. + // Replay the third step; we need the broker_code, but runFullFlow + // consumes it. Repeat the dance and try /token twice on the second go. + _ = tok + _, challenge := pkce(t, "verifier-reuse-2") + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + forgejoState := loc.Query().Get("state") + + resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_verifier": {"verifier-reuse-2"}, + } + first, _ := http.PostForm(fx.httpServer.URL+"/oauth/token", form) + first.Body.Close() + if first.StatusCode != http.StatusOK { + t.Fatalf("first /token status = %d", first.StatusCode) + } + second, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) + if err != nil { + t.Fatalf("second post: %v", err) + } + defer second.Body.Close() + if second.StatusCode != http.StatusBadRequest { + t.Errorf("second /token status = %d, want 400", second.StatusCode) + } +} + +func TestToken_Refresh_HappyPath(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + first := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-refresh-1") + + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {first.RefreshToken}, + "client_id": {cid}, + } + resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) + if err != nil { + t.Fatalf("refresh: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d: %s", resp.StatusCode, b) + } + var second tokenBundle + if err := json.NewDecoder(resp.Body).Decode(&second); err != nil { + t.Fatalf("decode: %v", err) + } + if second.AccessToken == "" || second.RefreshToken == "" { + t.Errorf("missing tokens: %+v", second) + } + if second.AccessToken == first.AccessToken { + t.Error("new access token must differ from the original") + } + if second.RefreshToken == first.RefreshToken { + t.Error("new refresh token must differ (rotation)") + } + + // The original refresh should now be invalid. + resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusBadRequest { + t.Errorf("reusing original refresh: status = %d, want 400", resp2.StatusCode) + } +} + +func TestToken_UnsupportedGrant(t *testing.T) { + fx := newFixture(t) + form := url.Values{"grant_type": {"client_credentials"}} + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + b, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(b), "unsupported_grant_type") { + t.Errorf("expected unsupported_grant_type: %s", b) + } +} + +// -------------------------------------------------------------------------- +// /oauth/revoke +// -------------------------------------------------------------------------- + +func TestRevoke_HappyPath(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-revoke") + + form := url.Values{"token": {tok.AccessToken}, "token_type_hint": {"access_token"}} + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("revoke status = %d, want 200", resp.StatusCode) + } + + // Refresh should still work for the unrelated refresh token here, but + // access-token validation lives in the /mcp middleware (phase 5b); + // we just verify the row got revoked by checking the DB directly. + var revokedAt int64 + row := fx.store.DB().QueryRow( + `SELECT IFNULL(revoked_at, 0) FROM access_tokens WHERE token_hash = + (SELECT MIN(token_hash) FROM access_tokens WHERE client_id = ?)`, + cid) + if err := row.Scan(&revokedAt); err != nil { + t.Fatalf("scan: %v", err) + } + if revokedAt == 0 { + t.Error("access_token row was not marked revoked") + } +} + +func TestRevoke_UnknownToken_StillReturns200(t *testing.T) { + // RFC 7009: the AS responds 200 even for unknown tokens, to prevent + // clients from probing token validity via the revoke endpoint. + fx := newFixture(t) + form := url.Values{"token": {"clearly-not-a-real-token"}} + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want 200", resp.StatusCode) + } +} + +func TestRevoke_MissingToken(t *testing.T) { + fx := newFixture(t) + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", url.Values{}); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// -------------------------------------------------------------------------- +// Server validation +// -------------------------------------------------------------------------- + +// -------------------------------------------------------------------------- +// Callback error paths +// -------------------------------------------------------------------------- + +func TestCallback_ForgejoReportedError(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-err", "client-st") + + // Forgejo redirects back with ?error=access_denied — we should + // forward that to the client, not crash. + q := url.Values{"error": {"access_denied"}, "error_description": {"user denied"}, "state": {forgejoState}} + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Fatalf("status = %d, want 302", resp.StatusCode) + } + loc := resp.Header.Get("Location") + if !strings.Contains(loc, "error=access_denied") { + t.Errorf("client redirect missing access_denied: %s", loc) + } + if !strings.Contains(loc, "state=client-st") { + t.Errorf("client redirect missing original state: %s", loc) + } +} + +func TestCallback_NoCodeNoError(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-noc", "st") + + // Empty callback (no code, no error) — should redirect with server_error. + q := url.Values{"state": {forgejoState}} + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if !strings.Contains(resp.Header.Get("Location"), "error=server_error") { + t.Errorf("expected server_error redirect: %s", resp.Header.Get("Location")) + } +} + +func TestCallback_ForgejoExchangeFails(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-fail", "st") + + // Make Forgejo's token endpoint return an error. + fx.fakeForgejo.tokenStatus = http.StatusBadRequest + fx.fakeForgejo.tokenError = "invalid_grant" + + q := url.Values{"code": {"u"}, "state": {forgejoState}} + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if !strings.Contains(resp.Header.Get("Location"), "error=server_error") { + t.Errorf("expected server_error: %s", resp.Header.Get("Location")) + } +} + +// -------------------------------------------------------------------------- +// Token grant error paths +// -------------------------------------------------------------------------- + +func TestToken_AuthCode_MissingFields(t *testing.T) { + fx := newFixture(t) + form := url.Values{"grant_type": {"authorization_code"}} + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestToken_AuthCode_BadCode(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {"made-up-code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_verifier": {"v"}, + } + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "invalid_grant") { + t.Errorf("expected invalid_grant: %s", body) + } +} + +func TestToken_AuthCode_ExpiredCode(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + + // Walk authorize + callback to mint a code, then advance the clock + // past the AuthCodeTTL before attempting /token. + verifier := "verifier-expire" + _, challenge := pkce(t, verifier) + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + forgejoState := loc.Query().Get("state") + + resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + + fx.advance(oauth.AuthCodeTTL + time.Minute) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_verifier": {verifier}, + } + resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp2.Body.Close() + body, _ := io.ReadAll(resp2.Body) + if resp2.StatusCode != http.StatusBadRequest || !strings.Contains(string(body), "expired") { + t.Errorf("expected expired error, got %d: %s", resp2.StatusCode, body) + } +} + +func TestToken_AuthCode_WrongClient(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + cid2 := fx.registerClient("https://other.example.com/cb") + + // Run authorize/callback for cid, then try /token under cid2. + verifier := "verifier-wrongclient" + _, challenge := pkce(t, verifier) + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + forgejoState := loc.Query().Get("state") + resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {cid2}, + "redirect_uri": {"https://app.example.com/cb"}, + "code_verifier": {verifier}, + } + resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp2.Body.Close() + body, _ := io.ReadAll(resp2.Body) + if !strings.Contains(string(body), "client_id mismatch") { + t.Errorf("expected client_id mismatch: %s", body) + } +} + +func TestToken_Refresh_MissingFields(t *testing.T) { + fx := newFixture(t) + form := url.Values{"grant_type": {"refresh_token"}} + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestToken_Refresh_UnknownToken(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {"made-up-token"}, + "client_id": {cid}, + } + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestToken_Refresh_WrongClientID(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + cid2 := fx.registerClient("https://other.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfwrong") + + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {tok.RefreshToken}, + "client_id": {cid2}, + } + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "client_id mismatch") { + t.Errorf("expected client_id mismatch: %s", body) + } +} + +func TestToken_Refresh_RevokedToken(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfrev") + + // Revoke the refresh token via /oauth/revoke. + revForm := url.Values{"token": {tok.RefreshToken}, "token_type_hint": {"refresh_token"}} + revResp, _ := http.PostForm(fx.httpServer.URL+"/oauth/revoke", revForm) + revResp.Body.Close() + + // Subsequent refresh attempt should fail. + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {tok.RefreshToken}, + "client_id": {cid}, + } + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "revoked") { + t.Errorf("expected revoked error: %s", body) + } +} + +func TestToken_Refresh_Expired(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfexp") + + fx.advance(oauth.RefreshTokenTTL + time.Hour) + + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {tok.RefreshToken}, + "client_id": {cid}, + } + resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "expired") { + t.Errorf("expected expired error: %s", body) + } +} + +func TestToken_AuthCode_WrongRedirectURI(t *testing.T) { + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + verifier := "verifier-wrongru" + _, challenge := pkce(t, verifier) + authQ := url.Values{ + "response_type": {"code"}, + "client_id": {cid}, + "redirect_uri": {"https://app.example.com/cb"}, + "state": {"st"}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + + url.Values{"code": {"u"}, "state": {loc.Query().Get("state")}}.Encode()) + resp.Body.Close() + cbLoc, _ := url.Parse(resp.Header.Get("Location")) + brokerCode := cbLoc.Query().Get("code") + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {brokerCode}, + "client_id": {cid}, + "redirect_uri": {"https://different.example.com/cb"}, + "code_verifier": {verifier}, + } + resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } + defer resp2.Body.Close() + body, _ := io.ReadAll(resp2.Body) + if !strings.Contains(string(body), "redirect_uri mismatch") { + t.Errorf("expected redirect_uri mismatch: %s", body) + } +} + +func TestToken_NoGrantType(t *testing.T) { + fx := newFixture(t) + resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", url.Values{}) + if err != nil { + t.Fatalf("post: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// -------------------------------------------------------------------------- +// Pending-auth reaper +// -------------------------------------------------------------------------- + +func TestReap_RemovesExpiredPending(t *testing.T) { + // Internal-API path: we can't call reapPending directly from the test + // package, but we can drive the same effect via /callback's "unknown + // state" path after time has advanced past PendingAuthTTL. + fx := newFixture(t) + cid := fx.registerClient("https://app.example.com/cb") + forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-reap", "st") + + fx.advance(oauth.PendingAuthTTL + time.Minute) + + // /callback with the (now-expired) state — pending entry's expiry is + // rejected directly. Either the reaper has cleaned it (unknown state) + // or the freshness check rejects it; both produce a 400. + q := url.Values{"code": {"u"}, "state": {forgejoState}} + resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()) + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// -------------------------------------------------------------------------- +// Server validation +// -------------------------------------------------------------------------- + +// startAuthorize walks just steps 1+2 of the flow (client → broker /authorize) +// and returns the forgejo state the broker stashed. Used by tests that need +// to drive /callback with arbitrary parameters afterwards. +func startAuthorize(t *testing.T, fx *fixture, clientID, redirectURI, verifier, clientState string) string { + t.Helper() + _, challenge := pkce(t, verifier) + q := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {redirectURI}, + "state": {clientState}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("authorize: %v", err) + } + resp.Body.Close() + loc, _ := url.Parse(resp.Header.Get("Location")) + state := loc.Query().Get("state") + if state == "" { + t.Fatalf("no forgejo state in redirect: %s", loc) + } + return state +} + +func TestNewServer_ValidationErrors(t *testing.T) { + cases := []struct { + name string + cfg oauth.Config + want string + }{ + {"no_store", oauth.Config{Forgejo: &forgejo.Client{}, Issuer: "https://x"}, "Store"}, + {"no_forgejo", oauth.Config{Store: &store.Store{}, Issuer: "https://x"}, "Forgejo"}, + {"no_issuer", oauth.Config{Store: &store.Store{}, Forgejo: &forgejo.Client{}}, "Issuer"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := oauth.NewServer(tc.cfg) + if err == nil || !strings.Contains(err.Error(), tc.want) { + t.Errorf("want error containing %q, got %v", tc.want, err) + } + }) + } +}