248 lines
6.9 KiB
Go
248 lines
6.9 KiB
Go
|
|
package store_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"sort"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
|
||
|
|
)
|
||
|
|
|
||
|
|
// expectedSchema maps each OAuth table name to the full set of column names
// that table is expected to expose.
//
// Only the set of names matters: the schema test sorts both the expected and
// the actual column lists before comparing them, so columns may be reordered
// in 0002 without breaking tests as long as none are added or removed.
var expectedSchema = map[string][]string{
	"clients": {
		"client_id",
		"client_secret",
		"redirect_uris",
		"metadata_json",
		"created_at",
		"last_used",
	},
	"auth_codes": {
		"code",
		"client_id",
		"redirect_uri",
		"code_challenge",
		"code_challenge_method",
		"scopes",
		"forgejo_access_token",
		"forgejo_refresh_token",
		"forgejo_token_expires_at",
		"forgejo_user_id",
		"forgejo_username",
		"expires_at",
		"used_at",
	},
	"access_tokens": {
		"token_hash",
		"client_id",
		"forgejo_user_id",
		"forgejo_username",
		"scopes",
		"forgejo_access_token",
		"forgejo_refresh_token",
		"forgejo_token_expires_at",
		"expires_at",
		"created_at",
		"revoked_at",
	},
	"refresh_tokens": {
		"token_hash",
		"access_token_hash",
		"client_id",
		"expires_at",
		"created_at",
		"revoked_at",
	},
}
|
||
|
|
|
||
|
|
func TestOAuthSchema_TablesAndColumns(t *testing.T) {
|
||
|
|
ctx := t.Context()
|
||
|
|
s, err := store.Open(ctx, tempStorePath(t))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Open: %v", err)
|
||
|
|
}
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
for table, want := range expectedSchema {
|
||
|
|
got := tableColumns(t, ctx, s, table)
|
||
|
|
sort.Strings(got)
|
||
|
|
w := append([]string(nil), want...)
|
||
|
|
sort.Strings(w)
|
||
|
|
if !equalStrings(got, w) {
|
||
|
|
t.Errorf("table %q columns = %v, want %v", table, got, w)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOAuthSchema_OAuthSchemaVersionRecorded(t *testing.T) {
|
||
|
|
ctx := t.Context()
|
||
|
|
s, err := store.Open(ctx, tempStorePath(t))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Open: %v", err)
|
||
|
|
}
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
var v string
|
||
|
|
if err := s.DB().QueryRowContext(ctx,
|
||
|
|
`SELECT value FROM broker_meta WHERE key = 'oauth_schema_version'`,
|
||
|
|
).Scan(&v); err != nil {
|
||
|
|
t.Fatalf("read oauth_schema_version: %v", err)
|
||
|
|
}
|
||
|
|
if v != "1" {
|
||
|
|
t.Errorf("oauth_schema_version = %q, want %q", v, "1")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOAuthSchema_ForeignKeyCascade(t *testing.T) {
|
||
|
|
// Foreign keys are enforced via the DSN pragma set in store.buildDSN.
|
||
|
|
// Verify the cascade works end-to-end: deleting a client tears down its
|
||
|
|
// dependent auth_codes, access_tokens, and refresh_tokens rows.
|
||
|
|
ctx := t.Context()
|
||
|
|
s, err := store.Open(ctx, tempStorePath(t))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Open: %v", err)
|
||
|
|
}
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO clients (client_id, redirect_uris, created_at)
|
||
|
|
VALUES ('c1', '["https://example.com/cb"]', 1000)`)
|
||
|
|
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO auth_codes
|
||
|
|
(code, client_id, redirect_uri, code_challenge, code_challenge_method,
|
||
|
|
scopes, forgejo_access_token, forgejo_token_expires_at,
|
||
|
|
forgejo_user_id, forgejo_username, expires_at)
|
||
|
|
VALUES
|
||
|
|
('code-1', 'c1', 'https://example.com/cb', 'chal', 'S256',
|
||
|
|
'read:user', 'forgejo-tok', 9999, 42, 'alice', 1600)`)
|
||
|
|
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO access_tokens
|
||
|
|
(token_hash, client_id, forgejo_user_id, forgejo_username, scopes,
|
||
|
|
forgejo_access_token, forgejo_token_expires_at,
|
||
|
|
expires_at, created_at)
|
||
|
|
VALUES
|
||
|
|
('hash-a', 'c1', 42, 'alice', 'read:user',
|
||
|
|
'forgejo-tok', 9999, 9999, 1000)`)
|
||
|
|
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO refresh_tokens
|
||
|
|
(token_hash, access_token_hash, client_id, expires_at, created_at)
|
||
|
|
VALUES
|
||
|
|
('hash-r', 'hash-a', 'c1', 99999, 1000)`)
|
||
|
|
|
||
|
|
// Deleting the client must cascade to everything else.
|
||
|
|
mustExec(t, ctx, s, `DELETE FROM clients WHERE client_id = 'c1'`)
|
||
|
|
|
||
|
|
for _, table := range []string{"auth_codes", "access_tokens", "refresh_tokens"} {
|
||
|
|
var n int
|
||
|
|
row := s.DB().QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table)
|
||
|
|
if err := row.Scan(&n); err != nil {
|
||
|
|
t.Fatalf("count %s: %v", table, err)
|
||
|
|
}
|
||
|
|
if n != 0 {
|
||
|
|
t.Errorf("%s row count after cascade delete = %d, want 0", table, n)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOAuthSchema_RefreshTokenCascadeFromAccessToken(t *testing.T) {
|
||
|
|
// access_tokens -> refresh_tokens cascade is independent of the
|
||
|
|
// client-level cascade. A refresh token outliving its access token is a
|
||
|
|
// bug class to prevent at the schema level.
|
||
|
|
ctx := t.Context()
|
||
|
|
s, err := store.Open(ctx, tempStorePath(t))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Open: %v", err)
|
||
|
|
}
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO clients (client_id, redirect_uris, created_at)
|
||
|
|
VALUES ('c1', '[]', 1000)`)
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO access_tokens
|
||
|
|
(token_hash, client_id, forgejo_user_id, forgejo_username, scopes,
|
||
|
|
forgejo_access_token, forgejo_token_expires_at, expires_at, created_at)
|
||
|
|
VALUES
|
||
|
|
('hash-a', 'c1', 42, 'alice', '', 'forgejo-tok', 9999, 9999, 1000)`)
|
||
|
|
mustExec(t, ctx, s,
|
||
|
|
`INSERT INTO refresh_tokens
|
||
|
|
(token_hash, access_token_hash, client_id, expires_at, created_at)
|
||
|
|
VALUES
|
||
|
|
('hash-r', 'hash-a', 'c1', 99999, 1000)`)
|
||
|
|
|
||
|
|
mustExec(t, ctx, s, `DELETE FROM access_tokens WHERE token_hash = 'hash-a'`)
|
||
|
|
|
||
|
|
var n int
|
||
|
|
if err := s.DB().QueryRowContext(ctx,
|
||
|
|
`SELECT COUNT(*) FROM refresh_tokens`).Scan(&n); err != nil {
|
||
|
|
t.Fatalf("count refresh_tokens: %v", err)
|
||
|
|
}
|
||
|
|
if n != 0 {
|
||
|
|
t.Errorf("refresh_tokens count after access_token delete = %d, want 0", n)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestOAuthSchema_IndexesPresent(t *testing.T) {
|
||
|
|
wantIndexes := []string{
|
||
|
|
"idx_clients_last_used",
|
||
|
|
"idx_auth_codes_expires_at",
|
||
|
|
"idx_access_tokens_expires_at",
|
||
|
|
"idx_access_tokens_forgejo_uid",
|
||
|
|
"idx_refresh_tokens_expires_at",
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx := t.Context()
|
||
|
|
s, err := store.Open(ctx, tempStorePath(t))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Open: %v", err)
|
||
|
|
}
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
rows, err := s.DB().QueryContext(ctx,
|
||
|
|
`SELECT name FROM sqlite_master WHERE type = 'index' AND name LIKE 'idx_%'`)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("query indexes: %v", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
got := map[string]bool{}
|
||
|
|
for rows.Next() {
|
||
|
|
var name string
|
||
|
|
if err := rows.Scan(&name); err != nil {
|
||
|
|
t.Fatalf("scan: %v", err)
|
||
|
|
}
|
||
|
|
got[name] = true
|
||
|
|
}
|
||
|
|
for _, idx := range wantIndexes {
|
||
|
|
if !got[idx] {
|
||
|
|
t.Errorf("missing index %q", idx)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func tableColumns(t *testing.T, ctx context.Context, s *store.Store, table string) []string {
|
||
|
|
t.Helper()
|
||
|
|
rows, err := s.DB().QueryContext(ctx, "PRAGMA table_info("+table+")")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("PRAGMA table_info(%s): %v", table, err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
var cols []string
|
||
|
|
for rows.Next() {
|
||
|
|
// PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
|
||
|
|
var cid int
|
||
|
|
var name, ctype string
|
||
|
|
var notnull, pk int
|
||
|
|
var dflt any
|
||
|
|
if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil {
|
||
|
|
t.Fatalf("scan column: %v", err)
|
||
|
|
}
|
||
|
|
cols = append(cols, name)
|
||
|
|
}
|
||
|
|
if len(cols) == 0 {
|
||
|
|
t.Fatalf("table %q has no columns (does it exist?)", table)
|
||
|
|
}
|
||
|
|
return cols
|
||
|
|
}
|
||
|
|
|
||
|
|
func mustExec(t *testing.T, ctx context.Context, s *store.Store, query string, args ...any) {
|
||
|
|
t.Helper()
|
||
|
|
if _, err := s.DB().ExecContext(ctx, query, args...); err != nil {
|
||
|
|
t.Fatalf("exec %q: %v", query, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// equalStrings reports whether a and b have the same length and hold equal
// strings at every index. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for i := 0; i < n; i++ {
		if a[i] == b[i] {
			continue
		}
		return false
	}
	return true
}
|