// forgejo-mcp-broker/internal/forgejo/forgejo_test.go

package forgejo_test
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/forgejo"
)
// newTestClient returns a Client pointed at baseURL. It is configured with the
// well-known credentials the tests assert against: ClientID "test-client-id",
// ClientSecret "test-client-secret", and UserAgent "test-broker". If NewClient
// rejects the configuration, the test fails immediately.
//
// Tests that need a different ClientConfig should build one and call
// forgejo.NewClient directly. This helper only removes boilerplate for the
// common case.
func newTestClient(t *testing.T, baseURL string) *forgejo.Client {
	t.Helper()
	cfg := forgejo.ClientConfig{
		BaseURL:      baseURL,
		ClientID:     "test-client-id",
		ClientSecret: "test-client-secret",
		UserAgent:    "test-broker",
	}
	c, err := forgejo.NewClient(cfg)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}
	return c
}
// TestNewClient_ValidationErrors checks that NewClient rejects incomplete or
// malformed configurations. Each rejection's message must name the offending
// field or problem.
func TestNewClient_ValidationErrors(t *testing.T) {
	tests := []struct {
		name string
		cfg  forgejo.ClientConfig
		want string
	}{
		{name: "no_base_url", cfg: forgejo.ClientConfig{ClientID: "id", ClientSecret: "s"}, want: "BaseURL"},
		{name: "no_client_id", cfg: forgejo.ClientConfig{BaseURL: "https://x", ClientSecret: "s"}, want: "ClientID"},
		{name: "no_client_secret", cfg: forgejo.ClientConfig{BaseURL: "https://x", ClientID: "id"}, want: "ClientSecret"},
		{name: "bad_scheme", cfg: forgejo.ClientConfig{BaseURL: "ftp://x", ClientID: "id", ClientSecret: "s"}, want: "http(s)"},
		{name: "no_host", cfg: forgejo.ClientConfig{BaseURL: "https://", ClientID: "id", ClientSecret: "s"}, want: "missing host"},
		{name: "unparseable_url", cfg: forgejo.ClientConfig{BaseURL: "://nope", ClientID: "id", ClientSecret: "s"}, want: "parse BaseURL"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := forgejo.NewClient(tt.cfg)
			if err != nil && strings.Contains(err.Error(), tt.want) {
				return
			}
			t.Errorf("want error containing %q, got %v", tt.want, err)
		})
	}
}
// TestNewClient_DefaultsApplied confirms that a config holding only the
// required fields is accepted. The defaults NewClient fills in have no public
// getter, so the test checks them indirectly by building an authorize URL.
func TestNewClient_DefaultsApplied(t *testing.T) {
	cfg := forgejo.ClientConfig{
		BaseURL:      "https://forgejo.example.com",
		ClientID:     "id",
		ClientSecret: "s",
	}
	c, err := forgejo.NewClient(cfg)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	opts := forgejo.AuthorizeURLOptions{
		RedirectURI:         "https://x/cb",
		State:               "st",
		CodeChallenge:       "cc",
		CodeChallengeMethod: "S256",
	}
	const wantPrefix = "https://forgejo.example.com/login/oauth/authorize?"
	if got := c.AuthorizeURL(opts); !strings.HasPrefix(got, wantPrefix) {
		t.Errorf("authorize URL prefix wrong: %s", got)
	}
}
// TestAuthorizeURL_AllParamsPresent verifies three things about the URL that
// AuthorizeURL builds:
//   - every AuthorizeURLOptions field is encoded into the query string;
//   - the fixed response_type and the configured client_id are also present;
//   - the path is the authorize endpoint.
func TestAuthorizeURL_AllParamsPresent(t *testing.T) {
	c := newTestClient(t, "https://forgejo.example.com")
	raw := c.AuthorizeURL(forgejo.AuthorizeURLOptions{
		RedirectURI:         "https://broker.example.com/oauth/callback",
		State:               "csrf-token",
		Scopes:              "read:user write:repository",
		CodeChallenge:       "challenge-string",
		CodeChallengeMethod: "S256",
	})

	parsed, err := url.Parse(raw)
	if err != nil {
		t.Fatalf("parse: %v", err)
	}

	query := parsed.Query()
	// An ordered table (rather than a map) keeps failure output stable
	// from run to run.
	checks := []struct{ key, want string }{
		{"response_type", "code"},
		{"client_id", "test-client-id"},
		{"redirect_uri", "https://broker.example.com/oauth/callback"},
		{"state", "csrf-token"},
		{"scope", "read:user write:repository"},
		{"code_challenge", "challenge-string"},
		{"code_challenge_method", "S256"},
	}
	for _, chk := range checks {
		if got := query.Get(chk.key); got != chk.want {
			t.Errorf("query[%q] = %q, want %q", chk.key, got, chk.want)
		}
	}

	if parsed.Path != "/login/oauth/authorize" {
		t.Errorf("path = %q, want /login/oauth/authorize", parsed.Path)
	}
}
// TestAuthorizeURL_OmitsScopeWhenEmpty verifies that no "scope" query
// parameter is emitted when AuthorizeURLOptions.Scopes is left empty.
func TestAuthorizeURL_OmitsScopeWhenEmpty(t *testing.T) {
	c := newTestClient(t, "https://forgejo.example.com")
	out := c.AuthorizeURL(forgejo.AuthorizeURLOptions{
		RedirectURI: "https://x/cb", State: "s", CodeChallenge: "c", CodeChallengeMethod: "S256",
	})
	// Fail cleanly on a parse error. Ignoring it would leave u nil and turn
	// the check below into a nil-pointer panic.
	u, err := url.Parse(out)
	if err != nil {
		t.Fatalf("parse %q: %v", out, err)
	}
	if u.Query().Has("scope") {
		t.Errorf("scope should not appear when empty: %s", out)
	}
}
// TestAuthorizeURL_BaseWithTrailingSlash verifies that a trailing slash on
// BaseURL does not distort the authorize endpoint path, for example into
// "//login/oauth/authorize".
func TestAuthorizeURL_BaseWithTrailingSlash(t *testing.T) {
	c := newTestClient(t, "https://forgejo.example.com/")
	out := c.AuthorizeURL(forgejo.AuthorizeURLOptions{
		RedirectURI: "https://x/cb", State: "s", CodeChallenge: "c", CodeChallengeMethod: "S256",
	})
	u, err := url.Parse(out)
	if err != nil {
		t.Fatalf("parse %q: %v", out, err)
	}
	// Compare the whole parsed path instead of searching for "com//login".
	// The substring check only caught one specific malformation.
	if u.Path != "/login/oauth/authorize" {
		t.Errorf("path = %q, want /login/oauth/authorize (full URL: %s)", u.Path, out)
	}
}
// fakeForgejo is a stand-in Forgejo OAuth server built on httptest.Server.
// Before issuing a request, a test sets the status/body fields to shape the
// canned responses. Afterwards it reads the last* fields to see what the
// client sent.
type fakeForgejo struct {
	t      *testing.T
	server *httptest.Server

	// Canned response for the token endpoint (/login/oauth/access_token).
	tokenStatus int
	tokenBody   string

	// Canned response for the userinfo endpoint (/login/oauth/userinfo).
	userStatus int
	userBody   string

	// Request data recorded by the handlers.
	lastForm url.Values // form parsed from the most recent token-endpoint request body
	lastAuth string     // Authorization header of the most recent userinfo request
}
// newFakeForgejo starts a fakeForgejo whose endpoints default to 200 OK with
// an empty body. It registers t.Cleanup to shut the server down when the test
// ends.
//
// The token endpoint parses the request body as URL-encoded form data and
// records it in lastForm. The userinfo endpoint records the Authorization
// header in lastAuth. Both reply with Content-Type application/json and the
// configured status and body.
func newFakeForgejo(t *testing.T) *fakeForgejo {
	t.Helper()
	f := &fakeForgejo{
		t:           t,
		tokenStatus: http.StatusOK,
		userStatus:  http.StatusOK,
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) {
		// An unreadable or unparseable body means the client sent something
		// wrong. Report it rather than silently recording an empty form.
		// This runs on the server's goroutine, so Errorf is safe here and
		// Fatalf would not be.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			f.t.Errorf("fake token endpoint: read body: %v", err)
		}
		form, err := url.ParseQuery(string(body))
		if err != nil {
			f.t.Errorf("fake token endpoint: parse form %q: %v", body, err)
		}
		f.lastForm = form
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(f.tokenStatus)
		_, _ = io.WriteString(w, f.tokenBody)
	})
	mux.HandleFunc("/login/oauth/userinfo", func(w http.ResponseWriter, r *http.Request) {
		f.lastAuth = r.Header.Get("Authorization")
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(f.userStatus)
		_, _ = io.WriteString(w, f.userBody)
	})
	f.server = httptest.NewServer(mux)
	t.Cleanup(f.server.Close)
	return f
}
// TestExchangeCode_Success checks that a 200 token response is decoded into
// the returned token. It also checks that the authorization-code grant is
// posted with the code, PKCE verifier, redirect URI and client credentials.
func TestExchangeCode_Success(t *testing.T) {
	const redirect = "https://broker.example.com/oauth/callback"

	f := newFakeForgejo(t)
	f.tokenBody = `{"access_token":"a","refresh_token":"r","token_type":"bearer","expires_in":3600,"scope":"read:user"}`
	c := newTestClient(t, f.server.URL)

	tok, err := c.ExchangeCode(t.Context(), "the-code", "the-verifier", redirect)
	if err != nil {
		t.Fatalf("ExchangeCode: %v", err)
	}
	tokenOK := tok.AccessToken == "a" && tok.RefreshToken == "r" && tok.ExpiresIn == 3600
	if !tokenOK {
		t.Errorf("token mismatch: %+v", tok)
	}

	// Form parameters the fake Forgejo should have received.
	wantForm := url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {"the-code"},
		"code_verifier": {"the-verifier"},
		"redirect_uri":  {redirect},
		"client_id":     {"test-client-id"},
		"client_secret": {"test-client-secret"},
	}
	for key := range wantForm {
		want := wantForm.Get(key)
		if got := f.lastForm.Get(key); got != want {
			t.Errorf("form[%q] = %q, want %q", key, got, want)
		}
	}
}
// TestExchangeCode_OAuthError checks that an RFC 6749 error response from the
// token endpoint becomes a *forgejo.Error. The error must keep the OAuth error
// code and HTTP status, and its message must include the description.
func TestExchangeCode_OAuthError(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenStatus = http.StatusBadRequest
	f.tokenBody = `{"error":"invalid_grant","error_description":"code expired"}`
	c := newTestClient(t, f.server.URL)

	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")

	var fe *forgejo.Error
	if !errors.As(err, &fe) {
		t.Fatalf("want *forgejo.Error, got %T: %v", err, err)
	}
	if got := fe.Code; got != "invalid_grant" {
		t.Errorf("Code = %q, want invalid_grant", got)
	}
	if fe.HTTPStatus != http.StatusBadRequest {
		t.Errorf("HTTPStatus = %d, want %d", fe.HTTPStatus, http.StatusBadRequest)
	}
	if msg := fe.Error(); !strings.Contains(msg, "expired") {
		t.Errorf("Error string missing description: %v", fe)
	}
}
// TestExchangeCode_OAuthErrorWithoutDescription checks the case where an OAuth
// error response carries only "error" and no "error_description". The result
// must still be a structured *forgejo.Error that keeps the code and status and
// produces a formatted message.
func TestExchangeCode_OAuthErrorWithoutDescription(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenStatus = http.StatusBadRequest
	f.tokenBody = `{"error":"invalid_request"}`
	c := newTestClient(t, f.server.URL)
	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")

	var e *forgejo.Error
	if !errors.As(err, &e) {
		t.Fatalf("want *forgejo.Error, got %T: %v", err, err)
	}
	// Assert the fields separately. The former single combined condition let a
	// mis-parsed code or status slip through as long as the message had a ':'.
	if e.Code != "invalid_request" {
		t.Errorf("Code = %q, want invalid_request", e.Code)
	}
	if e.HTTPStatus != http.StatusBadRequest {
		t.Errorf("HTTPStatus = %d, want %d", e.HTTPStatus, http.StatusBadRequest)
	}
	// NOTE(review): only the presence of a separator is pinned here. The exact
	// message format is left to the forgejo package.
	if msg := e.Error(); !strings.Contains(msg, ":") {
		t.Errorf("expected formatted error, got %q", msg)
	}
}
// TestExchangeCode_5xx_NoBody covers a 5xx reply whose body is not an OAuth
// JSON error (here, an HTML page). It must still surface as *forgejo.Error
// carrying the upstream HTTP status.
func TestExchangeCode_5xx_NoBody(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenStatus, f.tokenBody = http.StatusInternalServerError, "<html>oops</html>"
	c := newTestClient(t, f.server.URL)

	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")

	var fe *forgejo.Error
	if ok := errors.As(err, &fe); !ok {
		t.Fatalf("want *forgejo.Error, got %v", err)
	}
	if fe.HTTPStatus != http.StatusInternalServerError {
		t.Errorf("HTTPStatus = %d", fe.HTTPStatus)
	}
}
// TestExchangeCode_MalformedJSON checks that a 200 token response with an
// undecodable body is reported as a decode error.
func TestExchangeCode_MalformedJSON(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenStatus, f.tokenBody = http.StatusOK, "not json"
	c := newTestClient(t, f.server.URL)

	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")
	if isDecodeErr := err != nil && strings.Contains(err.Error(), "decode"); !isDecodeErr {
		t.Errorf("want decode error, got %v", err)
	}
}
// TestExchangeCode_MissingAccessToken checks that a token response without
// an access_token field is rejected, and that the error mentions the
// missing field.
func TestExchangeCode_MissingAccessToken(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenBody = `{"token_type":"bearer","expires_in":1}`
	c := newTestClient(t, f.server.URL)

	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")
	switch {
	case err == nil, !strings.Contains(err.Error(), "access_token"):
		t.Errorf("want missing access_token error, got %v", err)
	}
}
// TestExchangeCode_NetworkError checks that a transport failure is returned as
// a plain error, not as *forgejo.Error. That type is reserved for rejections
// Forgejo actually sent.
func TestExchangeCode_NetworkError(t *testing.T) {
	f := newFakeForgejo(t)
	c := newTestClient(t, f.server.URL)
	// Shut the server down first so the request below cannot connect.
	f.server.Close()

	_, err := c.ExchangeCode(t.Context(), "x", "v", "https://x/cb")
	if err == nil {
		t.Fatal("expected network error")
	}

	var fe *forgejo.Error
	if errors.As(err, &fe) {
		t.Errorf("network errors must not surface as *forgejo.Error: %v", err)
	}
}
// TestRefresh_Success checks that Refresh posts a refresh_token grant
// containing the supplied token, and that it decodes the new token pair.
func TestRefresh_Success(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenBody = `{"access_token":"a2","refresh_token":"r2","token_type":"bearer","expires_in":3600}`
	c := newTestClient(t, f.server.URL)

	tok, err := c.Refresh(t.Context(), "old-refresh-token")
	if err != nil {
		t.Fatalf("Refresh: %v", err)
	}
	if pairOK := tok.AccessToken == "a2" && tok.RefreshToken == "r2"; !pairOK {
		t.Errorf("token mismatch: %+v", tok)
	}

	if grant := f.lastForm.Get("grant_type"); grant != "refresh_token" {
		t.Errorf("grant_type = %q, want refresh_token", grant)
	}
	if sent := f.lastForm.Get("refresh_token"); sent != "old-refresh-token" {
		t.Errorf("refresh_token form param = %q", sent)
	}
}
// TestRefresh_InvalidGrant checks that a rejected (for example, revoked)
// refresh token surfaces as a *forgejo.Error with code invalid_grant.
func TestRefresh_InvalidGrant(t *testing.T) {
	f := newFakeForgejo(t)
	f.tokenStatus = http.StatusBadRequest
	f.tokenBody = `{"error":"invalid_grant","error_description":"refresh token revoked"}`
	c := newTestClient(t, f.server.URL)

	_, err := c.Refresh(t.Context(), "stale")

	var fe *forgejo.Error
	matched := errors.As(err, &fe) && fe.Code == "invalid_grant"
	if !matched {
		t.Errorf("want invalid_grant, got %v", err)
	}
}
// TestFetchUserInfo_Success checks that FetchUserInfo sends the access token as
// a Bearer Authorization header and decodes the userinfo claims.
func TestFetchUserInfo_Success(t *testing.T) {
	f := newFakeForgejo(t)
	f.userBody = `{"sub":"42","preferred_username":"alice","name":"Alice Bee","email":"alice@example.com"}`
	c := newTestClient(t, f.server.URL)

	info, err := c.FetchUserInfo(t.Context(), "the-access-token")
	if err != nil {
		t.Fatalf("FetchUserInfo: %v", err)
	}
	if info.Sub != "42" || info.PreferredUsername != "alice" {
		t.Errorf("user mismatch: %+v", info)
	}

	const wantAuth = "Bearer the-access-token"
	if f.lastAuth != wantAuth {
		t.Errorf("Authorization header = %q", f.lastAuth)
	}
}
// TestFetchUserInfo_OAuthError checks that an OAuth-shaped 401 from the
// userinfo endpoint surfaces as a *forgejo.Error with the upstream error code.
func TestFetchUserInfo_OAuthError(t *testing.T) {
	f := newFakeForgejo(t)
	f.userStatus = http.StatusUnauthorized
	f.userBody = `{"error":"invalid_token","error_description":"expired"}`
	c := newTestClient(t, f.server.URL)

	_, err := c.FetchUserInfo(t.Context(), "x")

	var fe *forgejo.Error
	matched := errors.As(err, &fe) && fe.Code == "invalid_token"
	if !matched {
		t.Errorf("want invalid_token, got %v", err)
	}
}
// TestFetchUserInfo_NonOAuthError covers Forgejo versions that answer a 401
// with {"message": "..."} instead of the RFC 6749 error shape. The client must
// still return a structured *forgejo.Error, using the fallback code
// "userinfo_failed" and the upstream status.
func TestFetchUserInfo_NonOAuthError(t *testing.T) {
	f := newFakeForgejo(t)
	f.userStatus, f.userBody = http.StatusUnauthorized, `{"message":"unauthenticated"}`
	c := newTestClient(t, f.server.URL)

	_, err := c.FetchUserInfo(t.Context(), "x")

	var fe *forgejo.Error
	if !errors.As(err, &fe) {
		t.Fatalf("want *forgejo.Error, got %v", err)
	}
	if got := fe.Code; got != "userinfo_failed" {
		t.Errorf("Code = %q, want userinfo_failed", got)
	}
	if fe.HTTPStatus != http.StatusUnauthorized {
		t.Errorf("HTTPStatus = %d", fe.HTTPStatus)
	}
}
// TestFetchUserInfo_MissingSub checks that a userinfo response without a
// "sub" claim is rejected, and that the error names the missing claim.
func TestFetchUserInfo_MissingSub(t *testing.T) {
	f := newFakeForgejo(t)
	f.userBody = `{"preferred_username":"alice"}`
	c := newTestClient(t, f.server.URL)

	_, err := c.FetchUserInfo(t.Context(), "x")
	if mentionsSub := err != nil && strings.Contains(err.Error(), "sub"); !mentionsSub {
		t.Errorf("want missing-sub error, got %v", err)
	}
}
// TestFetchUserInfo_MalformedJSON checks that a 200 userinfo response with an
// undecodable body is reported as a decode error.
func TestFetchUserInfo_MalformedJSON(t *testing.T) {
	f := newFakeForgejo(t)
	f.userBody = "not json"
	c := newTestClient(t, f.server.URL)

	_, err := c.FetchUserInfo(t.Context(), "x")
	switch {
	case err == nil, !strings.Contains(err.Error(), "decode"):
		t.Errorf("want decode error, got %v", err)
	}
}