Adds RFC 8414 (oauth-authorization-server) and RFC 9728 (oauth-protected-resource) metadata documents. Both URLs are derived from cfg.Issuer at construction time, never from inbound request headers. Test TestDiscovery_IssuerIgnoresHostHeader explicitly probes this — a malicious Host: evil.example.com value must not leak into the published metadata. Defense against the OAuth metadata-spoofing class starts at the discovery layer. Capabilities published reflect the actual OAuth surface: - response_types_supported = ["code"] - grant_types_supported = ["authorization_code", "refresh_token"] - code_challenge_methods_supported = ["S256"] (PKCE only, no plain) - token_endpoint_auth_methods_supported = ["none"] (PKCE-only public clients) Protected-resource metadata advertises /mcp as the resource; phase 5 will mount the gated MCP endpoint there. Closes forgejo-mcp-broker-b2o. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1150 lines
39 KiB
Go
1150 lines
39 KiB
Go
package oauth_test
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo"
|
|
brokerlog "kode.naiv.no/olemd/forgejo-mcp-broker/internal/log"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
|
|
)
|
|
|
|
// fixture bundles the harness each test wires up: a real store, a fake
|
|
// Forgejo instance, an OAuth server, and an httptest.Server fronting it.
|
|
type fixture struct {
|
|
t *testing.T
|
|
store *store.Store
|
|
server *oauth.Server
|
|
httpServer *httptest.Server
|
|
fakeForgejo *fakeForgejo
|
|
now func() time.Time
|
|
clock *atomic.Int64 // unix seconds
|
|
}
|
|
|
|
func newFixture(t *testing.T) *fixture {
|
|
t.Helper()
|
|
|
|
st, err := store.Open(t.Context(), filepath.Join(t.TempDir(), "broker.db"))
|
|
if err != nil {
|
|
t.Fatalf("store: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = st.Close() })
|
|
|
|
fake := newFakeForgejo(t)
|
|
|
|
fjClient, err := forgejo.NewClient(forgejo.ClientConfig{
|
|
BaseURL: fake.server.URL, ClientID: "broker-app", ClientSecret: "broker-secret",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("forgejo client: %v", err)
|
|
}
|
|
|
|
clock := &atomic.Int64{}
|
|
clock.Store(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC).Unix())
|
|
now := func() time.Time { return time.Unix(clock.Load(), 0).UTC() }
|
|
|
|
srv, err := oauth.NewServer(oauth.Config{
|
|
Store: st, Forgejo: fjClient,
|
|
Issuer: "https://broker.example.com",
|
|
Scopes: "read:user write:repository",
|
|
Now: now,
|
|
Log: brokerlog.Discard(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("oauth.NewServer: %v", err)
|
|
}
|
|
t.Cleanup(srv.Close)
|
|
|
|
httpSrv := httptest.NewServer(srv.Handler())
|
|
t.Cleanup(httpSrv.Close)
|
|
|
|
return &fixture{
|
|
t: t,
|
|
store: st,
|
|
server: srv,
|
|
httpServer: httpSrv,
|
|
fakeForgejo: fake,
|
|
now: now,
|
|
clock: clock,
|
|
}
|
|
}
|
|
|
|
// advance moves the test clock forward by d.
|
|
func (f *fixture) advance(d time.Duration) {
|
|
f.clock.Add(int64(d.Seconds()))
|
|
}
|
|
|
|
// registerClient calls /oauth/register with the given redirect URIs and
|
|
// returns the issued client_id.
|
|
func (f *fixture) registerClient(redirectURIs ...string) string {
|
|
f.t.Helper()
|
|
body, _ := json.Marshal(map[string]any{"redirect_uris": redirectURIs, "client_name": "Test"})
|
|
resp, err := http.Post(f.httpServer.URL+"/oauth/register",
|
|
"application/json", strings.NewReader(string(body)))
|
|
if err != nil {
|
|
f.t.Fatalf("register: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusCreated {
|
|
b, _ := io.ReadAll(resp.Body)
|
|
f.t.Fatalf("register status %d: %s", resp.StatusCode, b)
|
|
}
|
|
var rr struct {
|
|
ClientID string `json:"client_id"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
|
|
f.t.Fatalf("decode register: %v", err)
|
|
}
|
|
return rr.ClientID
|
|
}
|
|
|
|
// pkce derives the S256 code challenge for the supplied verifier and returns
// the pair (verifier, challenge). The challenge is the unpadded base64url
// encoding of SHA-256(verifier); the verifier is handed back unchanged.
func pkce(t *testing.T, verifier string) (string, string) {
	t.Helper()
	digest := sha256.Sum256([]byte(verifier))
	challenge := base64.RawURLEncoding.EncodeToString(digest[:])
	return verifier, challenge
}
|
|
|
|
// noRedirectClient never follows redirects: its CheckRedirect hook always
// returns http.ErrUseLastResponse, so the redirect response itself is handed
// back to the caller. Tests use it to inspect the 302s emitted by /authorize
// and /callback.
var noRedirectClient = &http.Client{
	CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse
	},
}
|
|
|
|
// fakeForgejo is a minimal stand-in for a Forgejo instance. It serves the
// OAuth token and userinfo endpoints, and tests adjust its fields to shape
// what those endpoints answer.
type fakeForgejo struct {
	t      *testing.T
	server *httptest.Server

	// Token endpoint behaviour: the HTTP status to send, the tokens to
	// issue, and — when non-empty — the error code to return instead.
	tokenStatus      int
	tokenAccessToken string
	tokenRefresh     string
	tokenError       string

	// Identity reported by the userinfo endpoint.
	userSub      string
	userUsername string
}
|
|
|
|
func newFakeForgejo(t *testing.T) *fakeForgejo {
|
|
t.Helper()
|
|
f := &fakeForgejo{
|
|
t: t,
|
|
tokenStatus: http.StatusOK,
|
|
tokenAccessToken: "fj-access",
|
|
tokenRefresh: "fj-refresh",
|
|
userSub: "42",
|
|
userUsername: "alice",
|
|
}
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(f.tokenStatus)
|
|
if f.tokenError != "" {
|
|
_, _ = io.WriteString(w, fmt.Sprintf(`{"error":%q}`, f.tokenError))
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, fmt.Sprintf(
|
|
`{"access_token":%q,"refresh_token":%q,"token_type":"bearer","expires_in":3600}`,
|
|
f.tokenAccessToken, f.tokenRefresh))
|
|
})
|
|
mux.HandleFunc("/login/oauth/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, fmt.Sprintf(
|
|
`{"sub":%q,"preferred_username":%q,"name":"Alice"}`, f.userSub, f.userUsername))
|
|
})
|
|
f.server = httptest.NewServer(mux)
|
|
t.Cleanup(f.server.Close)
|
|
return f
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// /oauth/register
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestRegister_HappyPath(t *testing.T) {
|
|
fx := newFixture(t)
|
|
body := `{"redirect_uris":["https://app.example.com/cb"],"client_name":"Demo"}`
|
|
resp, err := http.Post(fx.httpServer.URL+"/oauth/register",
|
|
"application/json", strings.NewReader(body))
|
|
if err != nil {
|
|
t.Fatalf("post: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusCreated {
|
|
t.Errorf("status = %d, want 201", resp.StatusCode)
|
|
}
|
|
var rr map[string]any
|
|
if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
if rr["client_id"] == "" {
|
|
t.Error("client_id empty")
|
|
}
|
|
if rr["token_endpoint_auth_method"] != "none" {
|
|
t.Errorf("auth_method = %v, want none", rr["token_endpoint_auth_method"])
|
|
}
|
|
}
|
|
|
|
func TestRegister_EmptyRedirectURIs(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, _err := http.Post(fx.httpServer.URL+"/oauth/register",
|
|
"application/json", strings.NewReader(`{"redirect_uris":[]}`))
|
|
if _err != nil {
|
|
t.Fatalf("http: %v", _err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestRegister_BadJSON(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, _err := http.Post(fx.httpServer.URL+"/oauth/register",
|
|
"application/json", strings.NewReader(`{not json`))
|
|
if _err != nil {
|
|
t.Fatalf("http: %v", _err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// /oauth/authorize
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestAuthorize_RedirectsToForgejo(t *testing.T) {
|
|
fx := newFixture(t)
|
|
clientID := fx.registerClient("https://app.example.com/cb")
|
|
|
|
_, challenge := pkce(t, "verifier-123")
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"client-csrf"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
"scope": {"read:user"},
|
|
}
|
|
resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode())
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusFound {
|
|
t.Fatalf("status = %d, want 302", resp.StatusCode)
|
|
}
|
|
loc := resp.Header.Get("Location")
|
|
if !strings.Contains(loc, fx.fakeForgejo.server.URL+"/login/oauth/authorize") {
|
|
t.Errorf("not redirected to Forgejo: %s", loc)
|
|
}
|
|
}
|
|
|
|
func TestAuthorize_UnknownClient(t *testing.T) {
|
|
fx := newFixture(t)
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {"nope"},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
}
|
|
resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestAuthorize_MismatchedRedirectURI(t *testing.T) {
|
|
fx := newFixture(t)
|
|
clientID := fx.registerClient("https://app.example.com/cb")
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {"https://evil.example.com/cb"},
|
|
}
|
|
resp, _err := http.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestAuthorize_MissingPKCE_RedirectsWithError(t *testing.T) {
|
|
fx := newFixture(t)
|
|
clientID := fx.registerClient("https://app.example.com/cb")
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
if resp.StatusCode != http.StatusFound {
|
|
t.Fatalf("status = %d, want 302", resp.StatusCode)
|
|
}
|
|
loc := resp.Header.Get("Location")
|
|
if !strings.Contains(loc, "error=invalid_request") {
|
|
t.Errorf("redirect missing error param: %s", loc)
|
|
}
|
|
if !strings.Contains(loc, "state=st") {
|
|
t.Errorf("redirect missing state echo: %s", loc)
|
|
}
|
|
}
|
|
|
|
func TestAuthorize_NonS256_RedirectsWithError(t *testing.T) {
|
|
fx := newFixture(t)
|
|
clientID := fx.registerClient("https://app.example.com/cb")
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_challenge": {"abc"},
|
|
"code_challenge_method": {"plain"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
if resp.StatusCode != http.StatusFound {
|
|
t.Fatalf("status = %d, want 302", resp.StatusCode)
|
|
}
|
|
if !strings.Contains(resp.Header.Get("Location"), "error=invalid_request") {
|
|
t.Errorf("missing error: %s", resp.Header.Get("Location"))
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// /oauth/callback (also exercises the rest of the happy-path flow)
|
|
// --------------------------------------------------------------------------
|
|
|
|
// runFullFlow walks an entire authorize → callback → token sequence and
|
|
// returns the final token response. Used by happy-path tests and by
|
|
// downstream tests (refresh, revoke) that need a valid token to play with.
|
|
func runFullFlow(t *testing.T, fx *fixture, redirectURI, clientID, verifier string) tokenBundle {
|
|
t.Helper()
|
|
_, challenge := pkce(t, verifier)
|
|
clientState := "client-csrf-" + verifier
|
|
|
|
// 1. /authorize: capture the redirect to Forgejo and extract its state.
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {redirectURI},
|
|
"state": {clientState},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
"scope": {"read:user"},
|
|
}
|
|
resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode())
|
|
if err != nil {
|
|
t.Fatalf("authorize: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
forgejoState := loc.Query().Get("state")
|
|
if forgejoState == "" {
|
|
t.Fatalf("authorize did not produce a forgejo state: %s", loc)
|
|
}
|
|
|
|
// 2. /callback: simulate Forgejo redirecting back with code + state.
|
|
cbQ := url.Values{"code": {"upstream-code"}, "state": {forgejoState}}
|
|
resp, err = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + cbQ.Encode())
|
|
if err != nil {
|
|
t.Fatalf("callback: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
if brokerCode == "" {
|
|
t.Fatalf("callback did not return broker code: %s", cbLoc)
|
|
}
|
|
if cbLoc.Query().Get("state") != clientState {
|
|
t.Errorf("callback dropped client_state: got %q want %q",
|
|
cbLoc.Query().Get("state"), clientState)
|
|
}
|
|
|
|
// 3. /token: exchange broker code for access+refresh tokens.
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {redirectURI},
|
|
"code_verifier": {verifier},
|
|
}
|
|
resp, err = http.PostForm(fx.httpServer.URL+"/oauth/token", form)
|
|
if err != nil {
|
|
t.Fatalf("token: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
b, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("token status = %d: %s", resp.StatusCode, b)
|
|
}
|
|
var tr tokenBundle
|
|
if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
|
|
t.Fatalf("decode token: %v", err)
|
|
}
|
|
return tr
|
|
}
|
|
|
|
// tokenBundle is the decoded JSON body of a successful /oauth/token
// response as consumed by these tests.
type tokenBundle struct {
	// Issued credentials.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	// Metadata accompanying the tokens.
	TokenType string `json:"token_type"`
	ExpiresIn int    `json:"expires_in"`
	Scope     string `json:"scope"`
}
|
|
|
|
func TestEndToEnd_RegisterAuthorizeCallbackToken(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-flow-1")
|
|
|
|
if tok.AccessToken == "" || tok.RefreshToken == "" {
|
|
t.Errorf("missing tokens: %+v", tok)
|
|
}
|
|
if tok.TokenType != "Bearer" {
|
|
t.Errorf("TokenType = %q, want Bearer", tok.TokenType)
|
|
}
|
|
}
|
|
|
|
func TestCallback_UnknownState(t *testing.T) {
|
|
fx := newFixture(t)
|
|
q := url.Values{"code": {"x"}, "state": {"unknown"}}
|
|
resp, _err := http.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// /oauth/token — additional grants and edge cases
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestToken_AuthCode_PKCEFails(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
|
|
// Authorize with one verifier, then try to exchange with a different one.
|
|
_, challenge := pkce(t, "good-verifier")
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
forgejoState := loc.Query().Get("state")
|
|
|
|
resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" +
|
|
url.Values{"code": {"u"}, "state": {forgejoState}}.Encode())
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_verifier": {"WRONG-verifier"},
|
|
}
|
|
resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
|
|
if err != nil {
|
|
t.Fatalf("post: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(body), "PKCE") && !strings.Contains(string(body), "invalid_grant") {
|
|
t.Errorf("expected PKCE error: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_AuthCode_Reuse(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-reuse")
|
|
|
|
// Same code can't be exchanged a second time.
|
|
// Replay the third step; we need the broker_code, but runFullFlow
|
|
// consumes it. Repeat the dance and try /token twice on the second go.
|
|
_ = tok
|
|
_, challenge := pkce(t, "verifier-reuse-2")
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
forgejoState := loc.Query().Get("state")
|
|
|
|
resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" +
|
|
url.Values{"code": {"u"}, "state": {forgejoState}}.Encode())
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_verifier": {"verifier-reuse-2"},
|
|
}
|
|
first, _ := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
|
|
first.Body.Close()
|
|
if first.StatusCode != http.StatusOK {
|
|
t.Fatalf("first /token status = %d", first.StatusCode)
|
|
}
|
|
second, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
|
|
if err != nil {
|
|
t.Fatalf("second post: %v", err)
|
|
}
|
|
defer second.Body.Close()
|
|
if second.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("second /token status = %d, want 400", second.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_HappyPath(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
first := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-refresh-1")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {first.RefreshToken},
|
|
"client_id": {cid},
|
|
}
|
|
resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", form)
|
|
if err != nil {
|
|
t.Fatalf("refresh: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
b, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("status = %d: %s", resp.StatusCode, b)
|
|
}
|
|
var second tokenBundle
|
|
if err := json.NewDecoder(resp.Body).Decode(&second); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
if second.AccessToken == "" || second.RefreshToken == "" {
|
|
t.Errorf("missing tokens: %+v", second)
|
|
}
|
|
if second.AccessToken == first.AccessToken {
|
|
t.Error("new access token must differ from the original")
|
|
}
|
|
if second.RefreshToken == first.RefreshToken {
|
|
t.Error("new refresh token must differ (rotation)")
|
|
}
|
|
|
|
// The original refresh should now be invalid.
|
|
resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp2.Body.Close()
|
|
if resp2.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("reusing original refresh: status = %d, want 400", resp2.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestToken_UnsupportedGrant(t *testing.T) {
|
|
fx := newFixture(t)
|
|
form := url.Values{"grant_type": {"client_credentials"}}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
b, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(b), "unsupported_grant_type") {
|
|
t.Errorf("expected unsupported_grant_type: %s", b)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// /oauth/revoke
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestRevoke_HappyPath(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-revoke")
|
|
|
|
form := url.Values{"token": {tok.AccessToken}, "token_type_hint": {"access_token"}}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("revoke status = %d, want 200", resp.StatusCode)
|
|
}
|
|
|
|
// Refresh should still work for the unrelated refresh token here, but
|
|
// access-token validation lives in the /mcp middleware (phase 5b);
|
|
// we just verify the row got revoked by checking the DB directly.
|
|
var revokedAt int64
|
|
row := fx.store.DB().QueryRow(
|
|
`SELECT IFNULL(revoked_at, 0) FROM access_tokens WHERE token_hash =
|
|
(SELECT MIN(token_hash) FROM access_tokens WHERE client_id = ?)`,
|
|
cid)
|
|
if err := row.Scan(&revokedAt); err != nil {
|
|
t.Fatalf("scan: %v", err)
|
|
}
|
|
if revokedAt == 0 {
|
|
t.Error("access_token row was not marked revoked")
|
|
}
|
|
}
|
|
|
|
func TestRevoke_UnknownToken_StillReturns200(t *testing.T) {
|
|
// RFC 7009: the AS responds 200 even for unknown tokens, to prevent
|
|
// clients from probing token validity via the revoke endpoint.
|
|
fx := newFixture(t)
|
|
form := url.Values{"token": {"clearly-not-a-real-token"}}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestRevoke_MissingToken(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", url.Values{}); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Callback error paths
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestCallback_ForgejoReportedError(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-err", "client-st")
|
|
|
|
// Forgejo redirects back with ?error=access_denied — we should
|
|
// forward that to the client, not crash.
|
|
q := url.Values{"error": {"access_denied"}, "error_description": {"user denied"}, "state": {forgejoState}}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusFound {
|
|
t.Fatalf("status = %d, want 302", resp.StatusCode)
|
|
}
|
|
loc := resp.Header.Get("Location")
|
|
if !strings.Contains(loc, "error=access_denied") {
|
|
t.Errorf("client redirect missing access_denied: %s", loc)
|
|
}
|
|
if !strings.Contains(loc, "state=client-st") {
|
|
t.Errorf("client redirect missing original state: %s", loc)
|
|
}
|
|
}
|
|
|
|
func TestCallback_NoCodeNoError(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-noc", "st")
|
|
|
|
// Empty callback (no code, no error) — should redirect with server_error.
|
|
q := url.Values{"state": {forgejoState}}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if !strings.Contains(resp.Header.Get("Location"), "error=server_error") {
|
|
t.Errorf("expected server_error redirect: %s", resp.Header.Get("Location"))
|
|
}
|
|
}
|
|
|
|
func TestCallback_ForgejoExchangeFails(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-cb-fail", "st")
|
|
|
|
// Make Forgejo's token endpoint return an error.
|
|
fx.fakeForgejo.tokenStatus = http.StatusBadRequest
|
|
fx.fakeForgejo.tokenError = "invalid_grant"
|
|
|
|
q := url.Values{"code": {"u"}, "state": {forgejoState}}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if !strings.Contains(resp.Header.Get("Location"), "error=server_error") {
|
|
t.Errorf("expected server_error: %s", resp.Header.Get("Location"))
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Token grant error paths
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestToken_AuthCode_MissingFields(t *testing.T) {
|
|
fx := newFixture(t)
|
|
form := url.Values{"grant_type": {"authorization_code"}}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestToken_AuthCode_BadCode(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {"made-up-code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_verifier": {"v"},
|
|
}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(body), "invalid_grant") {
|
|
t.Errorf("expected invalid_grant: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_AuthCode_ExpiredCode(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
|
|
// Walk authorize + callback to mint a code, then advance the clock
|
|
// past the AuthCodeTTL before attempting /token.
|
|
verifier := "verifier-expire"
|
|
_, challenge := pkce(t, verifier)
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
forgejoState := loc.Query().Get("state")
|
|
|
|
resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" +
|
|
url.Values{"code": {"u"}, "state": {forgejoState}}.Encode())
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
|
|
fx.advance(oauth.AuthCodeTTL + time.Minute)
|
|
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_verifier": {verifier},
|
|
}
|
|
resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp2.Body.Close()
|
|
body, _ := io.ReadAll(resp2.Body)
|
|
if resp2.StatusCode != http.StatusBadRequest || !strings.Contains(string(body), "expired") {
|
|
t.Errorf("expected expired error, got %d: %s", resp2.StatusCode, body)
|
|
}
|
|
}
|
|
|
|
func TestToken_AuthCode_WrongClient(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
cid2 := fx.registerClient("https://other.example.com/cb")
|
|
|
|
// Run authorize/callback for cid, then try /token under cid2.
|
|
verifier := "verifier-wrongclient"
|
|
_, challenge := pkce(t, verifier)
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
forgejoState := loc.Query().Get("state")
|
|
resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" +
|
|
url.Values{"code": {"u"}, "state": {forgejoState}}.Encode())
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {cid2},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"code_verifier": {verifier},
|
|
}
|
|
resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp2.Body.Close()
|
|
body, _ := io.ReadAll(resp2.Body)
|
|
if !strings.Contains(string(body), "client_id mismatch") {
|
|
t.Errorf("expected client_id mismatch: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_MissingFields(t *testing.T) {
|
|
fx := newFixture(t)
|
|
form := url.Values{"grant_type": {"refresh_token"}}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_UnknownToken(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {"made-up-token"},
|
|
"client_id": {cid},
|
|
}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_WrongClientID(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
cid2 := fx.registerClient("https://other.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfwrong")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {tok.RefreshToken},
|
|
"client_id": {cid2},
|
|
}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(body), "client_id mismatch") {
|
|
t.Errorf("expected client_id mismatch: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_RevokedToken(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfrev")
|
|
|
|
// Revoke the refresh token via /oauth/revoke.
|
|
revForm := url.Values{"token": {tok.RefreshToken}, "token_type_hint": {"refresh_token"}}
|
|
revResp, _ := http.PostForm(fx.httpServer.URL+"/oauth/revoke", revForm)
|
|
revResp.Body.Close()
|
|
|
|
// Subsequent refresh attempt should fail.
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {tok.RefreshToken},
|
|
"client_id": {cid},
|
|
}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(body), "revoked") {
|
|
t.Errorf("expected revoked error: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_Refresh_Expired(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
tok := runFullFlow(t, fx, "https://app.example.com/cb", cid, "verifier-rfexp")
|
|
|
|
fx.advance(oauth.RefreshTokenTTL + time.Hour)
|
|
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {tok.RefreshToken},
|
|
"client_id": {cid},
|
|
}
|
|
resp, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if !strings.Contains(string(body), "expired") {
|
|
t.Errorf("expected expired error: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_AuthCode_WrongRedirectURI(t *testing.T) {
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
verifier := "verifier-wrongru"
|
|
_, challenge := pkce(t, verifier)
|
|
authQ := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://app.example.com/cb"},
|
|
"state": {"st"},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, _err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + authQ.Encode()); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
resp, _ = noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" +
|
|
url.Values{"code": {"u"}, "state": {loc.Query().Get("state")}}.Encode())
|
|
resp.Body.Close()
|
|
cbLoc, _ := url.Parse(resp.Header.Get("Location"))
|
|
brokerCode := cbLoc.Query().Get("code")
|
|
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {brokerCode},
|
|
"client_id": {cid},
|
|
"redirect_uri": {"https://different.example.com/cb"},
|
|
"code_verifier": {verifier},
|
|
}
|
|
resp2, _err := http.PostForm(fx.httpServer.URL+"/oauth/token", form); if _err != nil { t.Fatalf("http: %v", _err) }
|
|
defer resp2.Body.Close()
|
|
body, _ := io.ReadAll(resp2.Body)
|
|
if !strings.Contains(string(body), "redirect_uri mismatch") {
|
|
t.Errorf("expected redirect_uri mismatch: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestToken_NoGrantType(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, err := http.PostForm(fx.httpServer.URL+"/oauth/token", url.Values{})
|
|
if err != nil {
|
|
t.Fatalf("post: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Pending-auth reaper
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestReap_RemovesExpiredPending(t *testing.T) {
|
|
// Internal-API path: we can't call reapPending directly from the test
|
|
// package, but we can drive the same effect via /callback's "unknown
|
|
// state" path after time has advanced past PendingAuthTTL.
|
|
fx := newFixture(t)
|
|
cid := fx.registerClient("https://app.example.com/cb")
|
|
forgejoState := startAuthorize(t, fx, cid, "https://app.example.com/cb", "verifier-reap", "st")
|
|
|
|
fx.advance(oauth.PendingAuthTTL + time.Minute)
|
|
|
|
// /callback with the (now-expired) state — pending entry's expiry is
|
|
// rejected directly. Either the reaper has cleaned it (unknown state)
|
|
// or the freshness check rejects it; both produce a 400.
|
|
q := url.Values{"code": {"u"}, "state": {forgejoState}}
|
|
resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/callback?" + q.Encode())
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Discovery (.well-known)
|
|
// --------------------------------------------------------------------------
|
|
|
|
func TestDiscovery_AuthorizationServerMetadata(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, err := http.Get(fx.httpServer.URL + "/.well-known/oauth-authorization-server")
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") {
|
|
t.Errorf("Content-Type = %q, want application/json", ct)
|
|
}
|
|
|
|
var md map[string]any
|
|
if err := json.NewDecoder(resp.Body).Decode(&md); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
|
|
// Issuer MUST come from cfg.Issuer, not from the request URL.
|
|
if md["issuer"] != "https://broker.example.com" {
|
|
t.Errorf("issuer = %v, want https://broker.example.com (config)", md["issuer"])
|
|
}
|
|
for _, k := range []string{"authorization_endpoint", "token_endpoint", "registration_endpoint", "revocation_endpoint"} {
|
|
v, _ := md[k].(string)
|
|
if !strings.HasPrefix(v, "https://broker.example.com/") {
|
|
t.Errorf("%s = %q, want issuer-rooted URL", k, v)
|
|
}
|
|
}
|
|
|
|
wantContains := map[string][]string{
|
|
"code_challenge_methods_supported": {"S256"},
|
|
"grant_types_supported": {"authorization_code", "refresh_token"},
|
|
"response_types_supported": {"code"},
|
|
"token_endpoint_auth_methods_supported": {"none"},
|
|
}
|
|
for field, vals := range wantContains {
|
|
got, _ := md[field].([]any)
|
|
if len(got) == 0 {
|
|
t.Errorf("%s missing or empty", field)
|
|
continue
|
|
}
|
|
for _, want := range vals {
|
|
found := false
|
|
for _, g := range got {
|
|
if g == want {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("%s does not include %q (got %v)", field, want, got)
|
|
}
|
|
}
|
|
}
|
|
|
|
scopes, _ := md["scopes_supported"].([]any)
|
|
if len(scopes) != 2 {
|
|
t.Errorf("scopes_supported = %v, want 2 entries", scopes)
|
|
}
|
|
}
|
|
|
|
func TestDiscovery_ProtectedResourceMetadata(t *testing.T) {
|
|
fx := newFixture(t)
|
|
resp, err := http.Get(fx.httpServer.URL + "/.well-known/oauth-protected-resource")
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
|
|
var md map[string]any
|
|
if err := json.NewDecoder(resp.Body).Decode(&md); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
if md["resource"] != "https://broker.example.com/mcp" {
|
|
t.Errorf("resource = %v, want issuer-rooted /mcp", md["resource"])
|
|
}
|
|
servers, _ := md["authorization_servers"].([]any)
|
|
if len(servers) != 1 || servers[0] != "https://broker.example.com" {
|
|
t.Errorf("authorization_servers = %v, want [config issuer]", servers)
|
|
}
|
|
bearer, _ := md["bearer_methods_supported"].([]any)
|
|
if len(bearer) == 0 || bearer[0] != "header" {
|
|
t.Errorf("bearer_methods_supported = %v, want [\"header\"]", bearer)
|
|
}
|
|
}
|
|
|
|
func TestDiscovery_IssuerIgnoresHostHeader(t *testing.T) {
|
|
// Crafting a malicious Host header must not leak that host into the
|
|
// discovery document. Defense against the metadata-spoofing class of
|
|
// OAuth attack starts here.
|
|
fx := newFixture(t)
|
|
req, err := http.NewRequest(http.MethodGet,
|
|
fx.httpServer.URL+"/.well-known/oauth-authorization-server", nil)
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
req.Host = "evil.example.com"
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if strings.Contains(string(body), "evil.example.com") {
|
|
t.Errorf("discovery doc leaked attacker-supplied Host: %s", body)
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------------
|
|
// Helpers used by /callback flow tests
|
|
// --------------------------------------------------------------------------
|
|
|
|
// startAuthorize walks just steps 1+2 of the flow (client → broker /authorize)
|
|
// and returns the forgejo state the broker stashed. Used by tests that need
|
|
// to drive /callback with arbitrary parameters afterwards.
|
|
func startAuthorize(t *testing.T, fx *fixture, clientID, redirectURI, verifier, clientState string) string {
|
|
t.Helper()
|
|
_, challenge := pkce(t, verifier)
|
|
q := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {clientID},
|
|
"redirect_uri": {redirectURI},
|
|
"state": {clientState},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
resp, err := noRedirectClient.Get(fx.httpServer.URL + "/oauth/authorize?" + q.Encode())
|
|
if err != nil {
|
|
t.Fatalf("authorize: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
loc, _ := url.Parse(resp.Header.Get("Location"))
|
|
state := loc.Query().Get("state")
|
|
if state == "" {
|
|
t.Fatalf("no forgejo state in redirect: %s", loc)
|
|
}
|
|
return state
|
|
}
|
|
|
|
func TestNewServer_ValidationErrors(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
cfg oauth.Config
|
|
want string
|
|
}{
|
|
{"no_store", oauth.Config{Forgejo: &forgejo.Client{}, Issuer: "https://x"}, "Store"},
|
|
{"no_forgejo", oauth.Config{Store: &store.Store{}, Issuer: "https://x"}, "Forgejo"},
|
|
{"no_issuer", oauth.Config{Store: &store.Store{}, Forgejo: &forgejo.Client{}}, "Issuer"},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := oauth.NewServer(tc.cfg)
|
|
if err == nil || !strings.Contains(err.Error(), tc.want) {
|
|
t.Errorf("want error containing %q, got %v", tc.want, err)
|
|
}
|
|
})
|
|
}
|
|
}
|