feat(oauth): bearer-token middleware (forgejo-mcp-broker-ytw)
Adds Authenticator.RequireBearer — http middleware that gates downstream
handlers on a valid broker access token.
Lookup path:
1. Read Authorization: Bearer <token> header.
2. SHA-256 the token, query access_tokens by token_hash.
3. Reject expired or revoked rows.
4. Build a Session (client_id, forgejo user info, upstream token,
scopes) and attach to r.Context() under a typed key.
Downstream handlers (the MCP endpoint shipping in 5a) read the
upstream Forgejo token via SessionFromContext to spawn forgejo-mcp
subprocesses scoped to the right user.
Failures emit 401 with an RFC 6750 §3 WWW-Authenticate header carrying
distinct error codes (invalid_request for missing/malformed headers,
invalid_token with reason=expired/revoked/unknown for token problems).
The body stays empty so a confused browser doesn't render auth errors;
all detail rides in the header where compliant clients look for it.
Tests: 90.9% on RequireBearer, 91.7% on lookupSession. Covers valid
token, missing/wrong-scheme/empty Authorization, unknown token,
expired token (clock-advanced past AccessTokenTTL), revoked token (via
the public /oauth/revoke endpoint).
Closes forgejo-mcp-broker-ytw.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
fee12a2ac0
commit
9c8cf40501
3 changed files with 375 additions and 2 deletions
165
internal/oauth/auth.go
Normal file
165
internal/oauth/auth.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
|
||||
)
|
||||
|
||||
// Session is the per-request OAuth context attached by the bearer
|
||||
// middleware. Downstream handlers (the MCP endpoint, in phase 5a) read
|
||||
// the upstream Forgejo token from here to spawn forgejo-mcp subprocesses
|
||||
// scoped to the right user.
|
||||
type Session struct {
|
||||
ClientID string
|
||||
ForgejoUserID int64
|
||||
ForgejoUsername string
|
||||
Scopes string
|
||||
BrokerTokenHash string // SHA-256 hex of the broker token; for log correlation
|
||||
ForgejoToken string // plaintext upstream token — keep in memory, never log
|
||||
ForgejoRefresh string
|
||||
ForgejoTokenExp time.Time
|
||||
}
|
||||
|
||||
// Authenticator resolves Bearer tokens against the access_tokens table.
|
||||
// Use Authenticator.RequireBearer to wrap the protected handler.
|
||||
type Authenticator struct {
|
||||
Store *store.Store
|
||||
Now func() time.Time // optional; defaults to time.Now
|
||||
}
|
||||
|
||||
type sessionCtxKey struct{}
|
||||
|
||||
// SessionFromContext returns the Session attached by RequireBearer, if any.
|
||||
func SessionFromContext(ctx context.Context) (*Session, bool) {
|
||||
s, ok := ctx.Value(sessionCtxKey{}).(*Session)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// RequireBearer is HTTP middleware that:
|
||||
// 1. Demands an `Authorization: Bearer <token>` header.
|
||||
// 2. Looks the token up by SHA-256 hash in access_tokens.
|
||||
// 3. Rejects expired or revoked tokens.
|
||||
// 4. Attaches the resolved Session to the request context for downstream
|
||||
// handlers to read via SessionFromContext.
|
||||
//
|
||||
// Failures emit a 401 with an RFC 6750 §3 WWW-Authenticate header carrying
|
||||
// the appropriate error code (invalid_token / invalid_request).
|
||||
func (a *Authenticator) RequireBearer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
now := a.Now
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
|
||||
raw := r.Header.Get("Authorization")
|
||||
if raw == "" {
|
||||
respondAuthError(w, "invalid_request", "missing Authorization header")
|
||||
return
|
||||
}
|
||||
token, ok := strings.CutPrefix(raw, "Bearer ")
|
||||
if !ok || token == "" {
|
||||
respondAuthError(w, "invalid_request", "Authorization header must use Bearer scheme")
|
||||
return
|
||||
}
|
||||
|
||||
sess, err := a.lookupSession(r.Context(), hashToken(token), now())
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errTokenNotFound):
|
||||
respondAuthError(w, "invalid_token", "unknown token")
|
||||
case errors.Is(err, errTokenExpired):
|
||||
respondAuthError(w, "invalid_token", "token expired")
|
||||
case errors.Is(err, errTokenRevoked):
|
||||
respondAuthError(w, "invalid_token", "token revoked")
|
||||
default:
|
||||
// Unexpected DB or scan error: don't leak internals to
|
||||
// the caller. Logging would land in middleware-of-the-
|
||||
// future once we wire a logger here.
|
||||
respondAuthError(w, "invalid_token", "auth lookup failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), sessionCtxKey{}, sess)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// Sentinel errors so RequireBearer can render distinct WWW-Authenticate
|
||||
// reasons for the operator while always returning 401 to the client.
|
||||
var (
|
||||
errTokenNotFound = errors.New("token not found")
|
||||
errTokenExpired = errors.New("token expired")
|
||||
errTokenRevoked = errors.New("token revoked")
|
||||
)
|
||||
|
||||
func (a *Authenticator) lookupSession(ctx context.Context, tokenHash string, now time.Time) (*Session, error) {
|
||||
var (
|
||||
clientID, fjUsername, scopes, fjAccess, fjRefresh string
|
||||
fjUserID int64
|
||||
expiresAt, fjExpiresAt int64
|
||||
revokedAt sql.NullInt64
|
||||
)
|
||||
row := a.Store.DB().QueryRowContext(ctx,
|
||||
`SELECT client_id, forgejo_user_id, forgejo_username, scopes,
|
||||
forgejo_access_token, forgejo_refresh_token, forgejo_token_expires_at,
|
||||
expires_at, revoked_at
|
||||
FROM access_tokens WHERE token_hash = ?`, tokenHash)
|
||||
err := row.Scan(&clientID, &fjUserID, &fjUsername, &scopes,
|
||||
&fjAccess, &fjRefresh, &fjExpiresAt, &expiresAt, &revokedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, errTokenNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if revokedAt.Valid {
|
||||
return nil, errTokenRevoked
|
||||
}
|
||||
if now.Unix() > expiresAt {
|
||||
return nil, errTokenExpired
|
||||
}
|
||||
|
||||
return &Session{
|
||||
ClientID: clientID,
|
||||
ForgejoUserID: fjUserID,
|
||||
ForgejoUsername: fjUsername,
|
||||
Scopes: scopes,
|
||||
BrokerTokenHash: tokenHash,
|
||||
ForgejoToken: fjAccess,
|
||||
ForgejoRefresh: fjRefresh,
|
||||
ForgejoTokenExp: time.Unix(fjExpiresAt, 0).UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// respondAuthError writes a 401 with a WWW-Authenticate header per RFC 6750
|
||||
// §3. The body stays empty — error info goes in the header so it's discoverable
|
||||
// to compliant clients without leaking detail in a body that browsers might
|
||||
// render.
|
||||
func respondAuthError(w http.ResponseWriter, errorCode, description string) {
|
||||
w.Header().Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer error="%s", error_description="%s"`,
|
||||
escapeHeader(errorCode), escapeHeader(description)))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// escapeHeader strips characters that would break a quoted-string in an
|
||||
// HTTP header value. Conservative: only allow safe ASCII. The error codes
|
||||
// we emit are well-known constants, so this is a defense-in-depth check
|
||||
// against a future bug accidentally interpolating user input.
|
||||
func escapeHeader(s string) string {
|
||||
var b strings.Builder
|
||||
for _, c := range s {
|
||||
if c >= 0x20 && c < 0x7f && c != '"' && c != '\\' {
|
||||
b.WriteRune(c)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
208
internal/oauth/auth_test.go
Normal file
208
internal/oauth/auth_test.go
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
package oauth_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/oauth"
|
||||
)
|
||||
|
||||
// authFixture wraps the OAuth fixture and exposes a fresh Authenticator
|
||||
// pointed at the same store and clock.
|
||||
type authFixture struct {
|
||||
*fixture
|
||||
auth *oauth.Authenticator
|
||||
}
|
||||
|
||||
func newAuthFixture(t *testing.T) *authFixture {
|
||||
t.Helper()
|
||||
fx := newFixture(t)
|
||||
return &authFixture{
|
||||
fixture: fx,
|
||||
auth: &oauth.Authenticator{Store: fx.store, Now: fx.now},
|
||||
}
|
||||
}
|
||||
|
||||
// echoHandler reads the Session from context and writes a recognisable
|
||||
// payload, so tests can confirm the right Session reached the handler.
|
||||
var echoHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, ok := oauth.SessionFromContext(r.Context())
|
||||
if !ok {
|
||||
http.Error(w, "no session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(w, "ok username="+sess.ForgejoUsername+
|
||||
" client="+sess.ClientID+
|
||||
" forgejo_token="+sess.ForgejoToken)
|
||||
})
|
||||
|
||||
func TestRequireBearer_ValidTokenPasses(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
cid := fx.registerClient("https://app.example.com/cb")
|
||||
tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-1")
|
||||
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("status = %d, want 200; body: %s", resp.StatusCode, body)
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if !strings.Contains(string(body), "username=alice") {
|
||||
t.Errorf("session not surfaced correctly: %s", body)
|
||||
}
|
||||
if !strings.Contains(string(body), "forgejo_token=fj-access") {
|
||||
t.Errorf("forgejo token not in session: %s", body)
|
||||
}
|
||||
if !strings.Contains(string(body), "client="+cid) {
|
||||
t.Errorf("client_id not in session: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_NoHeader_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "invalid_request") {
|
||||
t.Errorf("WWW-Authenticate = %q, want invalid_request", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_WrongScheme_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Basic abc==")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_EmptyToken_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_UnknownToken_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer made-up-not-in-store")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
if got := resp.Header.Get("WWW-Authenticate"); !strings.Contains(got, "invalid_token") {
|
||||
t.Errorf("WWW-Authenticate = %q, want invalid_token", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_ExpiredToken_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
cid := fx.registerClient("https://app.example.com/cb")
|
||||
tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-exp")
|
||||
|
||||
// Push the clock past the access-token lifetime.
|
||||
fx.advance(oauth.AccessTokenTTL + time.Minute)
|
||||
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
if !strings.Contains(resp.Header.Get("WWW-Authenticate"), "expired") {
|
||||
t.Errorf("WWW-Authenticate missing expired reason: %q", resp.Header.Get("WWW-Authenticate"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireBearer_RevokedToken_401(t *testing.T) {
|
||||
fx := newAuthFixture(t)
|
||||
cid := fx.registerClient("https://app.example.com/cb")
|
||||
tok := runFullFlow(t, fx.fixture, "https://app.example.com/cb", cid, "verifier-auth-rev")
|
||||
|
||||
// Revoke through the public /oauth/revoke endpoint.
|
||||
form := strings.NewReader("token=" + tok.AccessToken + "&token_type_hint=access_token")
|
||||
revResp, err := http.Post(fx.httpServer.URL+"/oauth/revoke", "application/x-www-form-urlencoded", form)
|
||||
if err != nil {
|
||||
t.Fatalf("revoke: %v", err)
|
||||
}
|
||||
revResp.Body.Close()
|
||||
|
||||
srv := httptest.NewServer(fx.auth.RequireBearer(echoHandler))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", resp.StatusCode)
|
||||
}
|
||||
if !strings.Contains(resp.Header.Get("WWW-Authenticate"), "revoked") {
|
||||
t.Errorf("WWW-Authenticate missing revoked reason: %q", resp.Header.Get("WWW-Authenticate"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionFromContext_NotPresent(t *testing.T) {
|
||||
if _, ok := oauth.SessionFromContext(t.Context()); ok {
|
||||
t.Error("SessionFromContext should return false on a context with no session attached")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue