package oauth_test

import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
)

// authFixture wraps the OAuth fixture and exposes a fresh Authenticator
// pointed at the same store and clock.
type authFixture struct {
	*fixture
	auth *oauth.Authenticator
}

// newAuthFixture builds the base fixture via newFixture and pairs it with an
// Authenticator constructed from that fixture's store and now fields.
func newAuthFixture(t *testing.T) *authFixture {
	t.Helper()
	fx := newFixture(t)
	return &authFixture{
		fixture: fx,
		auth:    &oauth.Authenticator{Store: fx.store, Now: fx.now},
	}
}

// echoHandler reads the Session from context and writes a recognisable
// payload, so tests can confirm the right Session reached the handler.
// When no Session is attached it answers 500 with "no session".
var echoHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	sess, ok := oauth.SessionFromContext(r.Context())
	if !ok {
		http.Error(w, "no session", http.StatusInternalServerError)
		return
	}
	// Best-effort write: a failure here shows up as a body mismatch in the
	// calling test, so the error is deliberately ignored.
	_, _ = io.WriteString(w, "ok username="+sess.ForgejoUsername+
		" client="+sess.ClientID+
		" forgejo_token="+sess.ForgejoToken)
})

// getProtected serves echoHandler behind fx.auth.RequireBearer on a throwaway
// test server, issues a single GET against it and returns the response along
// with its fully-read body.
//
// authorization is sent verbatim as the Authorization header; an empty string
// sends no Authorization header at all. The body is drained and closed before
// returning, so callers only need to inspect the status, headers and the
// returned body string. Any transport-level failure aborts the test.
func (fx *authFixture) getProtected(t *testing.T, authorization string) (*http.Response, string) {
	t.Helper()

	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	if authorization != "" {
		req.Header.Set("Authorization", authorization)
	}

	resp, err := srv.Client().Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	return resp, string(body)
}

// TestRequireBearer_ValidTokenPasses checks that a token obtained through the
// full flow is accepted and that the resulting Session reaches the handler.
func TestRequireBearer_ValidTokenPasses(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-1")

	resp, body := fx.getProtected(t, "Bearer "+tok.AccessToken)
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
	}
	if !strings.Contains(body, "username=alice") {
		t.Errorf("session not surfaced correctly: %s", body)
	}
	if !strings.Contains(body, "forgejo_token=fj-access") {
		t.Errorf("forgejo token not in session: %s", body)
	}
	if !strings.Contains(body, "client="+cid) {
		t.Errorf("client_id not in session: %s", body)
	}
}

// TestRequireBearer_NoHeader_401 checks that a request without an
// Authorization header is rejected with 401 and an invalid_request challenge.
func TestRequireBearer_NoHeader_401(t *testing.T) {
	fx := newAuthFixture(t)

	resp, _ := fx.getProtected(t, "")
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "invalid_request") {
		t.Errorf("WWW-Authenticate = %q, want invalid_request", got)
	}
}

// TestRequireBearer_WrongScheme_401 checks that a non-Bearer Authorization
// scheme is rejected with 401.
func TestRequireBearer_WrongScheme_401(t *testing.T) {
	fx := newAuthFixture(t)

	resp, _ := fx.getProtected(t, "Basic abc==")
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
}

// TestRequireBearer_EmptyToken_401 checks that a Bearer header carrying no
// token value is rejected with 401.
func TestRequireBearer_EmptyToken_401(t *testing.T) {
	fx := newAuthFixture(t)

	resp, _ := fx.getProtected(t, "Bearer ")
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
}

// TestRequireBearer_UnknownToken_401 checks that a token that was never
// issued is rejected with 401 and an invalid_token challenge.
func TestRequireBearer_UnknownToken_401(t *testing.T) {
	fx := newAuthFixture(t)

	resp, _ := fx.getProtected(t, "Bearer made-up-not-in-store")
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "invalid_token") {
		t.Errorf("WWW-Authenticate = %q, want invalid_token", got)
	}
}

// TestRequireBearer_ExpiredToken_401 checks that a token presented after its
// lifetime has elapsed is rejected with 401 and an "expired" reason.
func TestRequireBearer_ExpiredToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-exp")

	// Push the clock past the access-token lifetime.
	fx.advance(oauth.AccessTokenTTL + time.Minute)

	resp, _ := fx.getProtected(t, "Bearer "+tok.AccessToken)
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "expired") {
		t.Errorf("WWW-Authenticate missing expired reason: %q", got)
	}
}

// TestRequireBearer_RevokedToken_401 checks that a token revoked through the
// revocation endpoint is rejected with 401 and a "revoked" reason.
func TestRequireBearer_RevokedToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-rev")

	// Revoke through the public /oauth/revoke endpoint.
	form := strings.NewReader("token=" + tok.AccessToken + "&token_type_hint=access_token")
	revResp, err := http.Post(fx.httpServer.URL+"/oauth/revoke", "application/x-www-form-urlencoded", form)
	if err != nil {
		t.Fatalf("revoke: %v", err)
	}
	revResp.Body.Close()
	// Fail here rather than on the later assertions if the revocation itself
	// was refused; otherwise the real cause would be hidden.
	if revResp.StatusCode < 200 || revResp.StatusCode >= 300 {
		t.Fatalf("revoke status = %d, want 2xx", revResp.StatusCode)
	}

	resp, _ := fx.getProtected(t, "Bearer "+tok.AccessToken)
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "revoked") {
		t.Errorf("WWW-Authenticate missing revoked reason: %q", got)
	}
}

// TestSessionFromContext_NotPresent checks that SessionFromContext reports
// ok == false for a context that never had a Session attached.
func TestSessionFromContext_NotPresent(t *testing.T) {
	if _, ok := oauth.SessionFromContext(t.Context()); ok {
		t.Error("SessionFromContext should return false on a context with no session attached")
	}
}