208 lines
6.3 KiB
Go
208 lines
6.3 KiB
Go
|
|
package oauth_test
|
||
|
|
|
||
|
|
import (
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
)
|
||
|
|
|
||
|
|
// authFixture wraps the OAuth fixture and exposes a fresh Authenticator
|
||
|
|
// pointed at the same store and clock.
|
||
|
|
type authFixture struct {
|
||
|
|
*fixture
|
||
|
|
auth *oauth.Authenticator
|
||
|
|
}
|
||
|
|
|
||
|
|
func newAuthFixture(t *testing.T) *authFixture {
|
||
|
|
t.Helper()
|
||
|
|
fx := newFixture(t)
|
||
|
|
return &authFixture{
|
||
|
|
fixture: fx,
|
||
|
|
auth: &oauth.Authenticator{Store: fx.store, Now: fx.now},
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// echoHandler reads the Session from context and writes a recognisable
|
||
|
|
// payload, so tests can confirm the right Session reached the handler.
|
||
|
|
var echoHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
sess, ok := oauth.SessionFromContext(r.Context())
|
||
|
|
if !ok {
|
||
|
|
http.Error(w, "no session", http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
_, _ = io.WriteString(w, "ok username="+sess.ForgejoUsername+
|
||
|
|
" client="+sess.ClientID+
|
||
|
|
" forgejo_token="+sess.ForgejoToken)
|
||
|
|
})
|
||
|
|
|
||
|
|
// TestRequireBearer_ValidTokenPasses runs the full OAuth flow to mint a real
// access token, then checks that RequireBearer admits it and that the Session
// handed to the downstream handler carries the Forgejo username, the Forgejo
// token and the registering client's ID.
func TestRequireBearer_ValidTokenPasses(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-1")

	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	// Read the body once. It is needed both for the failure message on a
	// non-200 status and for the session assertions below.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
	}

	// "alice" and "fj-access" are presumably what the base fixture's Forgejo
	// side returns; verify against newFixture/runFullFlow if they change.
	got := string(body)
	if !strings.Contains(got, "username=alice") {
		t.Errorf("session not surfaced correctly: %s", body)
	}
	if !strings.Contains(got, "forgejo_token=fj-access") {
		t.Errorf("forgejo token not in session: %s", body)
	}
	if !strings.Contains(got, "client="+cid) {
		t.Errorf("client_id not in session: %s", body)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_NoHeader_401 sends a request with no Authorization header
// and expects a 401 whose WWW-Authenticate challenge names invalid_request.
func TestRequireBearer_NoHeader_401(t *testing.T) {
	fx := newAuthFixture(t)
	ts := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(ts.Close)

	res, err := http.Get(ts.URL)
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	defer res.Body.Close()

	if code := res.StatusCode; code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", code)
	}

	challenge := res.Header.Get("WWW-Authenticate")
	if !strings.Contains(challenge, "invalid_request") {
		t.Errorf("WWW-Authenticate = %q, want invalid_request", challenge)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_WrongScheme_401 sends credentials under a non-Bearer
// scheme (Basic) and expects RequireBearer to reject the request with 401.
func TestRequireBearer_WrongScheme_401(t *testing.T) {
	fx := newAuthFixture(t)
	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Basic abc==")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_EmptyToken_401 sends the Bearer scheme with no token
// after it and expects RequireBearer to reject the request with 401.
func TestRequireBearer_EmptyToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer ")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_UnknownToken_401 presents a well-formed Bearer token that
// was never issued and expects a 401 whose WWW-Authenticate challenge names
// invalid_token.
func TestRequireBearer_UnknownToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer made-up-not-in-store")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "invalid_token") {
		t.Errorf("WWW-Authenticate = %q, want invalid_token", got)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_ExpiredToken_401 mints a valid access token, advances the
// fixture clock past oauth.AccessTokenTTL, and expects RequireBearer to
// reject the token with a 401 whose WWW-Authenticate header mentions
// "expired".
func TestRequireBearer_ExpiredToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-exp")

	// Push the clock past the access-token lifetime. The Authenticator shares
	// the fixture's clock, so it observes the advanced time.
	fx.advance(oauth.AccessTokenTTL + time.Minute)

	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if challenge := resp.Header.Get("WWW-Authenticate"); !strings.Contains(challenge, "expired") {
		t.Errorf("WWW-Authenticate missing expired reason: %q", challenge)
	}
}
|
||
|
|
|
||
|
|
// TestRequireBearer_RevokedToken_401 mints a valid access token, revokes it
// through the public /oauth/revoke endpoint, and expects RequireBearer to
// reject it with a 401 whose WWW-Authenticate header mentions "revoked".
func TestRequireBearer_RevokedToken_401(t *testing.T) {
	fx := newAuthFixture(t)
	cid := fx.registerClient("https://app.example.com/cb")
	tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-rev")

	// Revoke through the public /oauth/revoke endpoint.
	//
	// Build the form with url.Values so the token is percent-encoded. A token
	// containing '+', '&' or '=' would otherwise be mangled by form decoding.
	// RFC 7009 servers still answer 200 for unknown tokens, so that failure
	// would be silent here and only show up later as a confusing assertion.
	form := url.Values{
		"token":           {tok.AccessToken},
		"token_type_hint": {"access_token"},
	}
	revResp, err := http.PostForm(fx.httpServer.URL+"/oauth/revoke", form)
	if err != nil {
		t.Fatalf("revoke: %v", err)
	}
	// The body is only used for the diagnostic below, so a read error is
	// not worth failing on separately.
	revBody, _ := io.ReadAll(revResp.Body)
	revResp.Body.Close()
	if revResp.StatusCode < 200 || revResp.StatusCode > 299 {
		t.Fatalf("revoke status = %d, want 2xx; body: %s", revResp.StatusCode, revBody)
	}

	srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", resp.StatusCode)
	}
	if challenge := resp.Header.Get("WWW-Authenticate"); !strings.Contains(challenge, "revoked") {
		t.Errorf("WWW-Authenticate missing revoked reason: %q", challenge)
	}
}
|
||
|
|
|
||
|
|
// TestSessionFromContext_NotPresent checks that SessionFromContext reports
// ok=false for a context that never had a Session attached.
func TestSessionFromContext_NotPresent(t *testing.T) {
	ctx := t.Context()

	_, found := oauth.SessionFromContext(ctx)
	if found {
		t.Error("SessionFromContext should return false on a context with no session attached")
	}
}
|