package oauth_test import ( "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "net/url" "path/filepath" "strings" "sync/atomic" "testing" "time" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo" brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth" "kode.naiv.no/olemd/forgejo-mcp-broker/internal/store" ) // fixture bundles the harness each test wires up: a real store, a fake // Forgejo instance, an OAuth server, and an httptest.Server fronting it. type fixture struct { t *testing.T store *store.Store server *oauth.Server httpServer *httptest.Server fakeForgejo *fakeForgejo now func() time.Time clock *atomic.Int64 // unix seconds } func newFixture(t *testing.T) *fixture { t.Helper() st, err := store.Open(t.Context(), filepath.Join(t.TempDir(), "broker.db")) if err != nil { t.Fatalf("store: %v", err) } t.Cleanup(func() { _ = st.Close() }) fake := newFakeForgejo(t) fjClient, err := forgejo.NewClient(forgejo.ClientConfig{ BaseURL: fake.server.URL, ClientID: "broker-app", ClientSecret: "broker-secret", }) if err != nil { t.Fatalf("forgejo client: %v", err) } clock := &atomic.Int64{} clock.Store(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC).Unix()) now := func() time.Time { return time.Unix(clock.Load(), 0).UTC() } srv, err := oauth.NewServer(oauth.Config{ Store: st, Forgejo: fjClient, Issuer: "https://broker.example.com", Scopes: "read:user write:repository", Now: now, Log: brokerlog.Discard(), }) if err != nil { t.Fatalf("oauth.NewServer: %v", err) } t.Cleanup(srv.Close) httpSrv := httptest.NewServer(srv.Handler()) t.Cleanup(httpSrv.Close) return &fixture{ t: t, store: st, server: srv, httpServer: httpSrv, fakeForgejo: fake, now: now, clock: clock, } } // advance moves the test clock forward by d. 
func (f *fixture) advance(d time.Duration) {
	// The clock holds whole unix seconds. Integer division drops the
	// sub-second remainder exactly, with no float64 round trip.
	f.clock.Add(int64(d / time.Second))
}

// registerClient calls /oauth/register with the given redirect URIs and
// returns the issued client_id. The test is aborted if the request cannot
// be encoded or sent, if the status is not 201, or if the response cannot
// be decoded or carries no client_id.
func (f *fixture) registerClient(redirectURIs ...string) string {
	f.t.Helper()

	payload, err := json.Marshal(map[string]any{"redirect_uris": redirectURIs, "client_name": "Test"})
	if err != nil {
		f.t.Fatalf("marshal register body: %v", err)
	}
	resp, err := http.Post(f.httpServer.URL+"/oauth/register", "application/json", strings.NewReader(string(payload)))
	if err != nil {
		f.t.Fatalf("register: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusCreated {
		b, _ := io.ReadAll(resp.Body) // best effort: body is only used for the failure message
		f.t.Fatalf("register status %d: %s", resp.StatusCode, b)
	}

	var rr struct {
		ClientID string `json:"client_id"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
		f.t.Fatalf("decode register: %v", err)
	}
	if rr.ClientID == "" {
		// Fail here rather than let an empty id surface as a confusing
		// error in whichever endpoint the caller hits next.
		f.t.Fatal("register: response carried no client_id")
	}
	return rr.ClientID
}

// pkce returns the supplied verifier together with its S256 challenge:
// the unpadded base64url encoding of SHA-256(verifier).
func pkce(t *testing.T, verifier string) (string, string) {
	t.Helper()
	sum := sha256.Sum256([]byte(verifier))
	return verifier, base64.RawURLEncoding.EncodeToString(sum[:])
}

// noRedirectClient is an http.Client that does not follow redirects, so
// tests can inspect the 302 responses /authorize and /callback emit.
var noRedirectClient = &http.Client{
	CheckRedirect: func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	},
}

// fakeForgejo is the same minimal Forgejo stub used in the bridge
// integration test, plus the OAuth token + userinfo endpoints needed here.
type fakeForgejo struct { t *testing.T server *httptest.Server tokenStatus int tokenAccessToken string tokenRefresh string tokenError string userSub string userUsername string } func newFakeForgejo(t *testing.T) *fakeForgejo { t.Helper() f := &fakeForgejo{ t: t, tokenStatus: http.StatusOK, tokenAccessToken: "fj-access", tokenRefresh: "fj-refresh", userSub: "42", userUsername: "alice", } mux := http.NewServeMux() mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(f.tokenStatus) if f.tokenError != "" { _, _ = io.WriteString(w, fmt.Sprintf(`{"error":%q}`, f.tokenError)) return } _, _ = io.WriteString(w, fmt.Sprintf( `{"access_token":%q,"refresh_token":%q,"token_type":"bearer","expires_in":3600}`, f.tokenAccessToken, f.tokenRefresh)) }) mux.HandleFunc("/login/oauth/userinfo", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, fmt.Sprintf( `{"sub":%q,"preferred_username":%q,"name":"Alice"}`, f.userSub, f.userUsername)) }) f.server = httptest.NewServer(mux) t.Cleanup(f.server.Close) return f } // -------------------------------------------------------------------------- // /oauth/register // -------------------------------------------------------------------------- func TestRegister_HappyPath(t *testing.T) { fx := newFixture(t) body := `{"redirect_uris":["https://app.example.com/cb"],"client_name":"Demo"}` resp, err := http.Post(fx.httpServer.URL+"/oauth/register", "application/json", strings.NewReader(body)) if err != nil { t.Fatalf("post: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { t.Errorf("status = %d, want 201", resp.StatusCode) } var rr map[string]any if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil { t.Fatalf("decode: %v", err) } if rr["client_id"] == "" { t.Error("client_id empty") } if rr["token_endpoint_auth_method"] != "none" { 
t.Errorf("auth_method = %v, want none", rr["token_endpoint_auth_method"]) } } func TestRegister_EmptyRedirectURIs(t *testing.T) { fx := newFixture(t) resp, _err := http.Post(fx.httpServer.URL+"/oauth/register", "application/json", strings.NewReader(`{"redirect_uris":[]}`)) if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestRegister_BadJSON(t *testing.T) { fx := newFixture(t) resp, _err := http.Post(fx.httpServer.URL+"/oauth/register", "application/json", strings.NewReader(`{not json`)) if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } // -------------------------------------------------------------------------- // /oauth/authorize // -------------------------------------------------------------------------- func TestAuthorize_RedirectsToForgejo(t *testing.T) { fx := newFixture(t) clientID := fx.registerClient("https://app.example.com/cb") _, challenge := pkce(t, "verifier-123") q := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"client-csrf"}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, "scope": {"read:user"}, } resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" 
+ q.Encode()) if err != nil { t.Fatalf("get: %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusFound { t.Fatalf("status = %d, want 302", resp.StatusCode) } loc := resp.Header.Get("Location") if !strings.Contains(loc, fx.fakeForgejo.server.URL+"/login/oauth/authorize") { t.Errorf("not redirected to Forgejo: %s", loc) } } func TestAuthorize_UnknownClient(t *testing.T) { fx := newFixture(t) q := url.Values{ "response_type": {"code"}, "client_id": {"nope"}, "redirect_uri": {"https://app.example.com/cb"}, } resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestAuthorize_MismatchedRedirectURI(t *testing.T) { fx := newFixture(t) clientID := fx.registerClient("https://app.example.com/cb") q := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {"https://evil.example.com/cb"}, } resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestAuthorize_MissingPKCE_RedirectsWithError(t *testing.T) { fx := newFixture(t) clientID := fx.registerClient("https://app.example.com/cb") q := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"st"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" 
+ q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } if resp.StatusCode != http.StatusFound { t.Fatalf("status = %d, want 302", resp.StatusCode) } loc := resp.Header.Get("Location") if !strings.Contains(loc, "error=invalid_request") { t.Errorf("redirect missing error param: %s", loc) } if !strings.Contains(loc, "state=st") { t.Errorf("redirect missing state echo: %s", loc) } } func TestAuthorize_NonS256_RedirectsWithError(t *testing.T) { fx := newFixture(t) clientID := fx.registerClient("https://app.example.com/cb") q := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {"https://app.example.com/cb"}, "code_challenge": {"abc"}, "code_challenge_method": {"plain"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } if resp.StatusCode != http.StatusFound { t.Fatalf("status = %d, want 302", resp.StatusCode) } if !strings.Contains(resp.Header.Get("Location"), "error=invalid_request") { t.Errorf("missing error: %s", resp.Header.Get("Location")) } } // -------------------------------------------------------------------------- // /oauth/callback (also exercises the rest of the happy-path flow) // -------------------------------------------------------------------------- // runFullFlow walks an entire authorize → callback → token sequence and // returns the final token response. Used by happy-path tests and by // downstream tests (refresh, revoke) that need a valid token to play with. func runFullFlow(t *testing.T, fx *fixture, redirectURI, clientID, verifier string) tokenBundle { t.Helper() _, challenge := pkce(t, verifier) clientState := "client-csrf-" + verifier // 1. /authorize: capture the redirect to Forgejo and extract its state. 
authQ := url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {redirectURI}, "state": {clientState}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, "scope": {"read:user"}, } resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()) if err != nil { t.Fatalf("authorize: %v", err) } resp.Body.Close() loc, _ := url.Parse(resp.Header.Get("Location")) forgejoState := loc.Query().Get("state") if forgejoState == "" { t.Fatalf("authorize did not produce a forgejo state: %s", loc) } // 2. /callback: simulate Forgejo redirecting back with code + state. cbQ := url.Values{"code": {"upstream-code"}, "state": {forgejoState}} resp, err = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + cbQ.Encode()) if err != nil { t.Fatalf("callback: %v", err) } resp.Body.Close() cbLoc, _ := url.Parse(resp.Header.Get("Location")) brokerCode := cbLoc.Query().Get("code") if brokerCode == "" { t.Fatalf("callback did not return broker code: %s", cbLoc) } if cbLoc.Query().Get("state") != clientState { t.Errorf("callback dropped client_state: got %q want %q", cbLoc.Query().Get("state"), clientState) } // 3. /token: exchange broker code for access+refresh tokens. 
form := url.Values{ "grant_type": {"authorization_code"}, "code": {brokerCode}, "client_id": {clientID}, "redirect_uri": {redirectURI}, "code_verifier": {verifier}, } resp, err = http.PostForm(fx.httpServer.URL+"/oauth/token", form) if err != nil { t.Fatalf("token: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { b, _ := io.ReadAll(resp.Body) t.Fatalf("token status = %d: %s", resp.StatusCode, b) } var tr tokenBundle if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { t.Fatalf("decode token: %v", err) } return tr } type tokenBundle struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` } func TestEndToEnd_RegisterAuthorizeCallbackToken(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-flow-1") if tok.AccessToken == "" || tok.RefreshToken == "" { t.Errorf("missing tokens: %+v", tok) } if tok.TokenType != "Bearer" { t.Errorf("TokenType = %q, want Bearer", tok.TokenType) } } func TestCallback_UnknownState(t *testing.T) { fx := newFixture(t) q := url.Values{"code": {"x"}, "state": {"unknown"}} resp, _err := http.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } // -------------------------------------------------------------------------- // /oauth/token — additional grants and edge cases // -------------------------------------------------------------------------- func TestToken_AuthCode_PKCEFails(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") // Authorize with one verifier, then try to exchange with a different one. 
_, challenge := pkce(t, "good-verifier") authQ := url.Values{ "response_type": {"code"}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"st"}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() loc, _ := url.Parse(resp.Header.Get("Location")) forgejoState := loc.Query().Get("state") resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) resp.Body.Close() cbLoc, _ := url.Parse(resp.Header.Get("Location")) brokerCode := cbLoc.Query().Get("code") form := url.Values{ "grant_type": {"authorization_code"}, "code": {brokerCode}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "code_verifier": {"WRONG-verifier"}, } resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) if err != nil { t.Fatalf("post: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) if !strings.Contains(string(body), "PKCE") && !strings.Contains(string(body), "invalid_grant") { t.Errorf("expected PKCE error: %s", body) } } func TestToken_AuthCode_Reuse(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-reuse") // Same code can't be exchanged a second time. // Replay the third step; we need the broker_code, but runFullFlow // consumes it. Repeat the dance and try /token twice on the second go. 
_ = tok _, challenge := pkce(t, "verifier-reuse-2") authQ := url.Values{ "response_type": {"code"}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"st"}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() loc, _ := url.Parse(resp.Header.Get("Location")) forgejoState := loc.Query().Get("state") resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) resp.Body.Close() cbLoc, _ := url.Parse(resp.Header.Get("Location")) brokerCode := cbLoc.Query().Get("code") form := url.Values{ "grant_type": {"authorization_code"}, "code": {brokerCode}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "code_verifier": {"verifier-reuse-2"}, } first, _ := http.PostForm(fx.httpServer.URL+"/oauth/token", form) first.Body.Close() if first.StatusCode != http.StatusOK { t.Fatalf("first /token status = %d", first.StatusCode) } second, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) if err != nil { t.Fatalf("second post: %v", err) } defer second.Body.Close() if second.StatusCode != http.StatusBadRequest { t.Errorf("second /token status = %d, want 400", second.StatusCode) } } func TestToken_Refresh_HappyPath(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") first := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-refresh-1") form := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {first.RefreshToken}, "client_id": {cid}, } resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form) if err != nil { t.Fatalf("refresh: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { b, _ := io.ReadAll(resp.Body) t.Fatalf("status = %d: %s", resp.StatusCode, b) } var second tokenBundle if err := 
json.NewDecoder(resp.Body).Decode(&second); err != nil { t.Fatalf("decode: %v", err) } if second.AccessToken == "" || second.RefreshToken == "" { t.Errorf("missing tokens: %+v", second) } if second.AccessToken == first.AccessToken { t.Error("new access token must differ from the original") } if second.RefreshToken == first.RefreshToken { t.Error("new refresh token must differ (rotation)") } // The original refresh should now be invalid. resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp2.Body.Close() if resp2.StatusCode != http.StatusBadRequest { t.Errorf("reusing original refresh: status = %d, want 400", resp2.StatusCode) } } func TestToken_UnsupportedGrant(t *testing.T) { fx := newFixture(t) form := url.Values{"grant_type": {"client_credentials"}} resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } b, _ := io.ReadAll(resp.Body) if !strings.Contains(string(b), "unsupported_grant_type") { t.Errorf("expected unsupported_grant_type: %s", b) } } // -------------------------------------------------------------------------- // /oauth/revoke // -------------------------------------------------------------------------- func TestRevoke_HappyPath(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-revoke") form := url.Values{"token": {tok.AccessToken}, "token_type_hint": {"access_token"}} resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("revoke status = %d, want 200", resp.StatusCode) } // Refresh should still work for the unrelated refresh token here, but // access-token 
validation lives in the /mcp middleware (phase 5b); // we just verify the row got revoked by checking the DB directly. var revokedAt int64 row := fx.store.DB().QueryRow( `SELECT IFNULL(revoked_at, 0) FROM access_tokens WHERE token_hash = (SELECT MIN(token_hash) FROM access_tokens WHERE client_id = ?)`, cid) if err := row.Scan(&revokedAt); err != nil { t.Fatalf("scan: %v", err) } if revokedAt == 0 { t.Error("access_token row was not marked revoked") } } func TestRevoke_UnknownToken_StillReturns200(t *testing.T) { // RFC 7009: the AS responds 200 even for unknown tokens, to prevent // clients from probing token validity via the revoke endpoint. fx := newFixture(t) form := url.Values{"token": {"clearly-not-a-real-token"}} resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want 200", resp.StatusCode) } } func TestRevoke_MissingToken(t *testing.T) { fx := newFixture(t) resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", url.Values{}); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } // -------------------------------------------------------------------------- // Callback error paths // -------------------------------------------------------------------------- func TestCallback_ForgejoReportedError(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-err", "client-st") // Forgejo redirects back with ?error=access_denied — we should // forward that to the client, not crash. q := url.Values{"error": {"access_denied"}, "error_description": {"user denied"}, "state": {forgejoState}} resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" 
+ q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusFound { t.Fatalf("status = %d, want 302", resp.StatusCode) } loc := resp.Header.Get("Location") if !strings.Contains(loc, "error=access_denied") { t.Errorf("client redirect missing access_denied: %s", loc) } if !strings.Contains(loc, "state=client-st") { t.Errorf("client redirect missing original state: %s", loc) } } func TestCallback_NoCodeNoError(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-noc", "st") // Empty callback (no code, no error) — should redirect with server_error. q := url.Values{"state": {forgejoState}} resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if !strings.Contains(resp.Header.Get("Location"), "error=server_error") { t.Errorf("expected server_error redirect: %s", resp.Header.Get("Location")) } } func TestCallback_ForgejoExchangeFails(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-fail", "st") // Make Forgejo's token endpoint return an error. fx.fakeForgejo.tokenStatus = http.StatusBadRequest fx.fakeForgejo.tokenError = "invalid_grant" q := url.Values{"code": {"u"}, "state": {forgejoState}} resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" 
+ q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if !strings.Contains(resp.Header.Get("Location"), "error=server_error") { t.Errorf("expected server_error: %s", resp.Header.Get("Location")) } } // -------------------------------------------------------------------------- // Token grant error paths // -------------------------------------------------------------------------- func TestToken_AuthCode_MissingFields(t *testing.T) { fx := newFixture(t) form := url.Values{"grant_type": {"authorization_code"}} resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestToken_AuthCode_BadCode(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") form := url.Values{ "grant_type": {"authorization_code"}, "code": {"made-up-code"}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "code_verifier": {"v"}, } resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) if !strings.Contains(string(body), "invalid_grant") { t.Errorf("expected invalid_grant: %s", body) } } func TestToken_AuthCode_ExpiredCode(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") // Walk authorize + callback to mint a code, then advance the clock // past the AuthCodeTTL before attempting /token. 
verifier := "verifier-expire" _, challenge := pkce(t, verifier) authQ := url.Values{ "response_type": {"code"}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"st"}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() loc, _ := url.Parse(resp.Header.Get("Location")) forgejoState := loc.Query().Get("state") resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) resp.Body.Close() cbLoc, _ := url.Parse(resp.Header.Get("Location")) brokerCode := cbLoc.Query().Get("code") fx.advance(oauth.AuthCodeTTL + time.Minute) form := url.Values{ "grant_type": {"authorization_code"}, "code": {brokerCode}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "code_verifier": {verifier}, } resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp2.Body.Close() body, _ := io.ReadAll(resp2.Body) if resp2.StatusCode != http.StatusBadRequest || !strings.Contains(string(body), "expired") { t.Errorf("expected expired error, got %d: %s", resp2.StatusCode, body) } } func TestToken_AuthCode_WrongClient(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") cid2 := fx.registerClient("https://other.example.com/cb") // Run authorize/callback for cid, then try /token under cid2. verifier := "verifier-wrongclient" _, challenge := pkce(t, verifier) authQ := url.Values{ "response_type": {"code"}, "client_id": {cid}, "redirect_uri": {"https://app.example.com/cb"}, "state": {"st"}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, } resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" 
+ authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() loc, _ := url.Parse(resp.Header.Get("Location")) forgejoState := loc.Query().Get("state") resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + url.Values{"code": {"u"}, "state": {forgejoState}}.Encode()) resp.Body.Close() cbLoc, _ := url.Parse(resp.Header.Get("Location")) brokerCode := cbLoc.Query().Get("code") form := url.Values{ "grant_type": {"authorization_code"}, "code": {brokerCode}, "client_id": {cid2}, "redirect_uri": {"https://app.example.com/cb"}, "code_verifier": {verifier}, } resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp2.Body.Close() body, _ := io.ReadAll(resp2.Body) if !strings.Contains(string(body), "client_id mismatch") { t.Errorf("expected client_id mismatch: %s", body) } } func TestToken_Refresh_MissingFields(t *testing.T) { fx := newFixture(t) form := url.Values{"grant_type": {"refresh_token"}} resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestToken_Refresh_UnknownToken(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") form := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {"made-up-token"}, "client_id": {cid}, } resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadRequest { t.Errorf("status = %d, want 400", resp.StatusCode) } } func TestToken_Refresh_WrongClientID(t *testing.T) { fx := newFixture(t) cid := fx.registerClient("https://app.example.com/cb") cid2 := fx.registerClient("https://other.example.com/cb") tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfwrong") form := 
url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {tok.RefreshToken},
		"client_id":     {cid2},
	}
	resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
	if _err != nil {
		t.Fatalf("http: %v", _err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	if !strings.Contains(string(body), "client_id mismatch") {
		t.Errorf("expected client_id mismatch: %s", body)
	}
}

// TestToken_Refresh_RevokedToken checks that a refresh token revoked through
// /oauth/revoke can no longer be exchanged at /oauth/token: the token
// response body must mention "revoked".
func TestToken_Refresh_RevokedToken(t *testing.T) {
	fx := newFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfrev")

	// Revoke the refresh token via /oauth/revoke. A transport failure must
	// stop the test here: the response would be nil, and closing its body
	// would panic instead of reporting a useful failure.
	revForm := url.Values{"token": {tok.RefreshToken}, "token_type_hint": {"refresh_token"}}
	revResp, err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", revForm)
	if err != nil {
		t.Fatalf("revoke: %v", err)
	}
	revResp.Body.Close()

	// Subsequent refresh attempt should fail.
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {tok.RefreshToken},
		"client_id":     {cid},
	}
	resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
	if err != nil {
		t.Fatalf("http: %v", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if !strings.Contains(string(body), "revoked") {
		t.Errorf("expected revoked error: %s", body)
	}
}

// TestToken_Refresh_Expired advances the test clock past
// oauth.RefreshTokenTTL and checks that the refresh grant is then rejected
// with a body mentioning "expired".
func TestToken_Refresh_Expired(t *testing.T) {
	fx := newFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfexp")

	fx.advance(oauth.RefreshTokenTTL + time.Hour)

	form := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {tok.RefreshToken},
		"client_id":     {cid},
	}
	resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
	if err != nil {
		t.Fatalf("http: %v", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if !strings.Contains(string(body), "expired") {
		t.Errorf("expected expired error: %s", body)
	}
}

// TestToken_AuthCode_WrongRedirectURI obtains a broker authorization code for
// one redirect URI and then redeems it at /oauth/token with a different
// redirect_uri; the response body must mention "redirect_uri mismatch".
func TestToken_AuthCode_WrongRedirectURI(t *testing.T) {
	fx := newFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	verifier := "verifier-wrongru"

	// Client → broker /authorize; yields the state the broker sent upstream.
	forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", verifier, "st")

	// Upstream → broker /callback; the redirect back to the client carries
	// the broker-issued code.
	cbQuery := url.Values{"code": {"u"}, "state": {forgejoState}}
	cbResp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + cbQuery.Encode())
	if err != nil {
		t.Fatalf("callback: %v", err)
	}
	cbResp.Body.Close()
	cbLoc, err := url.Parse(cbResp.Header.Get("Location"))
	if err != nil {
		t.Fatalf("parse callback redirect: %v", err)
	}
	brokerCode := cbLoc.Query().Get("code")
	if brokerCode == "" {
		// Without a code the token request below would be rejected for an
		// unrelated reason, so the mismatch check would prove nothing.
		t.Fatalf("no broker code in callback redirect (status %d): %s", cbResp.StatusCode, cbLoc)
	}

	form := url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {brokerCode},
		"client_id":     {cid},
		"redirect_uri":  {"https://different.example.com/cb"},
		"code_verifier": {verifier},
	}
	tokResp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
	if err != nil {
		t.Fatalf("http: %v", err)
	}
	defer tokResp.Body.Close()
	body, err := io.ReadAll(tokResp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if !strings.Contains(string(body), "redirect_uri mismatch") {
		t.Errorf("expected redirect_uri mismatch: %s", body)
	}
}

// TestToken_NoGrantType posts an empty form to /oauth/token and expects a
// 400 Bad Request.
func TestToken_NoGrantType(t *testing.T) {
	fx := newFixture(t)
	resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", url.Values{})
	if err != nil {
		t.Fatalf("post: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("status = %d, want 400", resp.StatusCode)
	}
}

// --------------------------------------------------------------------------
// Pending-auth reaper
// --------------------------------------------------------------------------

// TestReap_RemovesExpiredPending checks that a pending authorization older
// than oauth.PendingAuthTTL can no longer be completed through /callback.
func TestReap_RemovesExpiredPending(t *testing.T) {
	// Internal-API path: we can't call reapPending directly from the test
	// package, but we can drive the same effect via /callback's "unknown
	// state" path after time has advanced past PendingAuthTTL.
	fx := newFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-reap", "st")

	fx.advance(oauth.PendingAuthTTL + time.Minute)

	// Call /callback with the now-expired state. Either the reaper has
	// already removed the pending entry (unknown state) or the freshness
	// check rejects it; both outcomes produce a 400.
	q := url.Values{"code": {"u"}, "state": {forgejoState}}
	resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode())
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("status = %d, want 400", resp.StatusCode)
	}
}

// --------------------------------------------------------------------------
// Discovery (.well-known)
// --------------------------------------------------------------------------

// TestDiscovery_AuthorizationServerMetadata fetches
// /.well-known/oauth-authorization-server and checks the status, content
// type, issuer, issuer-rooted endpoint URLs, the advertised capability lists,
// and the number of supported scopes.
func TestDiscovery_AuthorizationServerMetadata(t *testing.T) {
	fx := newFixture(t)
	resp, err := http.Get(fx.httpServer.URL + "/.well-known/oauth-authorization-server")
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Errorf("status = %d, want 200", resp.StatusCode)
	}
	if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") {
		t.Errorf("Content-Type = %q, want application/json", ct)
	}
	var md map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&md); err != nil {
		t.Fatalf("decode: %v", err)
	}

	// Issuer MUST come from cfg.Issuer, not from the request URL.
	if md["issuer"] != "https://broker.example.com" {
		t.Errorf("issuer = %v, want https://broker.example.com (config)", md["issuer"])
	}
	for _, k := range []string{"authorization_endpoint", "token_endpoint", "registration_endpoint", "revocation_endpoint"} {
		// A missing or non-string value yields "" and fails the prefix check.
		v, _ := md[k].(string)
		if !strings.HasPrefix(v, "https://broker.example.com/") {
			t.Errorf("%s = %q, want issuer-rooted URL", k, v)
		}
	}

	// Each listed field must be a non-empty JSON array containing at least
	// the given values; extra entries are allowed.
	wantContains := map[string][]string{
		"code_challenge_methods_supported":      {"S256"},
		"grant_types_supported":                 {"authorization_code", "refresh_token"},
		"response_types_supported":              {"code"},
		"token_endpoint_auth_methods_supported": {"none"},
	}
	for field, vals := range wantContains {
		got, _ := md[field].([]any)
		if len(got) == 0 {
			t.Errorf("%s missing or empty", field)
			continue
		}
		for _, want := range vals {
			found := false
			for _, g := range got {
				if g == want {
					found = true
					break
				}
			}
			if !found {
				t.Errorf("%s does not include %q (got %v)", field, want, got)
			}
		}
	}

	scopes, _ := md["scopes_supported"].([]any)
	if len(scopes) != 2 {
		t.Errorf("scopes_supported = %v, want 2 entries", scopes)
	}
}

// TestDiscovery_ProtectedResourceMetadata fetches
// /.well-known/oauth-protected-resource and checks the resource URL, the
// single authorization server entry, and the first bearer method.
func TestDiscovery_ProtectedResourceMetadata(t *testing.T) {
	fx := newFixture(t)
	resp, err := http.Get(fx.httpServer.URL + "/.well-known/oauth-protected-resource")
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Errorf("status = %d, want 200", resp.StatusCode)
	}
	var md map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&md); err != nil {
		t.Fatalf("decode: %v", err)
	}
	if md["resource"] != "https://broker.example.com/mcp" {
		t.Errorf("resource = %v, want issuer-rooted /mcp", md["resource"])
	}
	servers, _ := md["authorization_servers"].([]any)
	if len(servers) != 1 || servers[0] != "https://broker.example.com" {
		t.Errorf("authorization_servers = %v, want [config issuer]", servers)
	}
	bearer, _ := md["bearer_methods_supported"].([]any)
	if len(bearer) == 0 || bearer[0] != "header" {
		t.Errorf("bearer_methods_supported = %v, want [\"header\"]", bearer)
	}
}

// TestDiscovery_IssuerIgnoresHostHeader requests the authorization server
// metadata with a forged Host header and checks that the forged host name
// does not appear anywhere in the response body.
func TestDiscovery_IssuerIgnoresHostHeader(t *testing.T) {
	// Crafting a malicious Host header must not leak that host into the
	// discovery document. Defense against the metadata-spoofing class of
	// OAuth attack starts here.
	fx := newFixture(t)
	req, err := http.NewRequest(http.MethodGet, fx.httpServer.URL+"/.well-known/oauth-authorization-server", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Host = "evil.example.com"
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()
	// A failed read would leave an empty or truncated body, letting the
	// "no leak" check below pass without having seen the whole document.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if strings.Contains(string(body), "evil.example.com") {
		t.Errorf("discovery doc leaked attacker-supplied Host: %s", body)
	}
}

// --------------------------------------------------------------------------
// Helpers used by /callback flow tests
// --------------------------------------------------------------------------

// startAuthorize walks just steps 1+2 of the flow (client → broker /authorize)
// and returns the forgejo state the broker stashed. Used by tests that need
// to drive /callback with arbitrary parameters afterwards.
//
// The PKCE challenge is derived from verifier. The test is aborted if the
// request fails, the Location header cannot be parsed, or the redirect
// carries no state parameter.
func startAuthorize(t *testing.T, fx *fixture, clientID, redirectURI, verifier, clientState string) string {
	t.Helper()
	_, challenge := pkce(t, verifier)
	q := url.Values{
		"response_type":         {"code"},
		"client_id":             {clientID},
		"redirect_uri":          {redirectURI},
		"state":                 {clientState},
		"code_challenge":        {challenge},
		"code_challenge_method": {"S256"},
	}
	resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode())
	if err != nil {
		t.Fatalf("authorize: %v", err)
	}
	resp.Body.Close()
	// url.Parse returns a nil URL on error; check it before dereferencing.
	loc, err := url.Parse(resp.Header.Get("Location"))
	if err != nil {
		t.Fatalf("parse authorize redirect: %v", err)
	}
	state := loc.Query().Get("state")
	if state == "" {
		t.Fatalf("no forgejo state in redirect: %s", loc)
	}
	return state
}

// TestNewServer_ValidationErrors checks that oauth.NewServer rejects a Config
// missing Store, Forgejo, or Issuer, and that the error names the missing
// field.
func TestNewServer_ValidationErrors(t *testing.T) {
	cases := []struct {
		name string
		cfg  oauth.Config
		want string
	}{
		{"no_store", oauth.Config{Forgejo: &forgejo.Client{}, Issuer: "https://x"}, "Store"},
		{"no_forgejo", oauth.Config{Store: &store.Store{}, Issuer: "https://x"}, "Forgejo"},
		{"no_issuer", oauth.Config{Store: &store.Store{}, Forgejo: &forgejo.Client{}}, "Issuer"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := oauth.NewServer(tc.cfg)
			if err == nil || !strings.Contains(err.Error(), tc.want) {
				t.Errorf("want error containing %q, got %v", tc.want, err)
			}
		})
	}
}