package store_test

import (
	"context"
	"slices"
	"sort"
	"testing"

	"kode.naiv.no/olemd/forgejo-mcp-broker/internal/store"
)

// expectedSchema lists every column we expect each OAuth table to have. The
// test sorts both sides before comparison, so column order in 0002 is free
// to change without breaking tests as long as the set is stable.
var expectedSchema = map[string][]string{
	"clients": {
		"client_id", "client_secret", "redirect_uris", "metadata_json",
		"created_at", "last_used",
	},
	"auth_codes": {
		"code", "client_id", "redirect_uri", "code_challenge",
		"code_challenge_method", "scopes", "forgejo_access_token",
		"forgejo_refresh_token", "forgejo_token_expires_at",
		"forgejo_user_id", "forgejo_username", "expires_at", "used_at",
	},
	"access_tokens": {
		"token_hash", "client_id", "forgejo_user_id", "forgejo_username",
		"scopes", "forgejo_access_token", "forgejo_refresh_token",
		"forgejo_token_expires_at", "expires_at", "created_at", "revoked_at",
	},
	"refresh_tokens": {
		"token_hash", "access_token_hash", "client_id", "expires_at",
		"created_at", "revoked_at",
	},
}

// TestOAuthSchema_TablesAndColumns checks that every OAuth table exists and
// has exactly the column set listed in expectedSchema. Each table runs as its
// own subtest so that a missing table (tableColumns calls t.Fatalf) does not
// hide mismatches in the tables that would have been checked after it.
func TestOAuthSchema_TablesAndColumns(t *testing.T) {
	ctx := t.Context()
	s, err := store.Open(ctx, tempStorePath(t))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer s.Close()

	for table, want := range expectedSchema {
		t.Run(table, func(t *testing.T) {
			got := tableColumns(t, ctx, s, table)
			sort.Strings(got)
			// Sort a copy so the shared expectedSchema entry is never mutated.
			w := slices.Clone(want)
			sort.Strings(w)
			if !equalStrings(got, w) {
				t.Errorf("table %q columns = %v, want %v", table, got, w)
			}
		})
	}
}

// TestOAuthSchema_OAuthSchemaVersionRecorded checks that opening a fresh store
// records oauth_schema_version = "1" in broker_meta.
func TestOAuthSchema_OAuthSchemaVersionRecorded(t *testing.T) {
	ctx := t.Context()
	s, err := store.Open(ctx, tempStorePath(t))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer s.Close()

	var v string
	if err := s.DB().QueryRowContext(ctx,
		`SELECT value FROM broker_meta WHERE key = 'oauth_schema_version'`,
	).Scan(&v); err != nil {
		t.Fatalf("read oauth_schema_version: %v", err)
	}
	if v != "1" {
		t.Errorf("oauth_schema_version = %q, want %q", v, "1")
	}
}

func TestOAuthSchema_ForeignKeyCascade(t *testing.T) {
	// Foreign keys are enforced via the DSN pragma set in store.buildDSN.
	// Verify the cascade works end-to-end: deleting a client tears down its
	// dependent auth_codes, access_tokens, and refresh_tokens rows.
	ctx := t.Context()
	s, err := store.Open(ctx, tempStorePath(t))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer s.Close()

	mustExec(t, ctx, s, `INSERT INTO clients (client_id, redirect_uris, created_at) VALUES ('c1', '["https://example.com/cb"]', 1000)`)
	mustExec(t, ctx, s, `INSERT INTO auth_codes (code, client_id, redirect_uri, code_challenge, code_challenge_method, scopes, forgejo_access_token, forgejo_token_expires_at, forgejo_user_id, forgejo_username, expires_at) VALUES ('code-1', 'c1', 'https://example.com/cb', 'chal', 'S256', 'read:user', 'forgejo-tok', 9999, 42, 'alice', 1600)`)
	mustExec(t, ctx, s, `INSERT INTO access_tokens (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, forgejo_access_token, forgejo_token_expires_at, expires_at, created_at) VALUES ('hash-a', 'c1', 42, 'alice', 'read:user', 'forgejo-tok', 9999, 9999, 1000)`)
	mustExec(t, ctx, s, `INSERT INTO refresh_tokens (token_hash, access_token_hash, client_id, expires_at, created_at) VALUES ('hash-r', 'hash-a', 'c1', 99999, 1000)`)

	// Deleting the client must cascade to everything else.
	mustExec(t, ctx, s, `DELETE FROM clients WHERE client_id = 'c1'`)

	for _, table := range []string{"auth_codes", "access_tokens", "refresh_tokens"} {
		var n int
		// table comes from the fixed list above, so concatenation is safe.
		row := s.DB().QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table)
		if err := row.Scan(&n); err != nil {
			t.Fatalf("count %s: %v", table, err)
		}
		if n != 0 {
			t.Errorf("%s row count after cascade delete = %d, want 0", table, n)
		}
	}
}

func TestOAuthSchema_RefreshTokenCascadeFromAccessToken(t *testing.T) {
	// access_tokens -> refresh_tokens cascade is independent of the
	// client-level cascade. A refresh token outliving its access token is a
	// bug class to prevent at the schema level.
	ctx := t.Context()
	s, err := store.Open(ctx, tempStorePath(t))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer s.Close()

	mustExec(t, ctx, s, `INSERT INTO clients (client_id, redirect_uris, created_at) VALUES ('c1', '[]', 1000)`)
	mustExec(t, ctx, s, `INSERT INTO access_tokens (token_hash, client_id, forgejo_user_id, forgejo_username, scopes, forgejo_access_token, forgejo_token_expires_at, expires_at, created_at) VALUES ('hash-a', 'c1', 42, 'alice', '', 'forgejo-tok', 9999, 9999, 1000)`)
	mustExec(t, ctx, s, `INSERT INTO refresh_tokens (token_hash, access_token_hash, client_id, expires_at, created_at) VALUES ('hash-r', 'hash-a', 'c1', 99999, 1000)`)

	mustExec(t, ctx, s, `DELETE FROM access_tokens WHERE token_hash = 'hash-a'`)

	var n int
	if err := s.DB().QueryRowContext(ctx, `SELECT COUNT(*) FROM refresh_tokens`).Scan(&n); err != nil {
		t.Fatalf("count refresh_tokens: %v", err)
	}
	if n != 0 {
		t.Errorf("refresh_tokens count after access_token delete = %d, want 0", n)
	}
}

// TestOAuthSchema_IndexesPresent checks that each expected idx_* index exists
// in sqlite_master after the store is opened.
func TestOAuthSchema_IndexesPresent(t *testing.T) {
	wantIndexes := []string{
		"idx_clients_last_used",
		"idx_auth_codes_expires_at",
		"idx_access_tokens_expires_at",
		"idx_access_tokens_forgejo_uid",
		"idx_refresh_tokens_expires_at",
	}

	ctx := t.Context()
	s, err := store.Open(ctx, tempStorePath(t))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	defer s.Close()

	// Note: '_' is a single-character LIKE wildcard, so this filter is looser
	// than it looks. That is harmless here: only presence of the wanted
	// names is checked, never the absence of others.
	rows, err := s.DB().QueryContext(ctx,
		`SELECT name FROM sqlite_master WHERE type = 'index' AND name LIKE 'idx_%'`)
	if err != nil {
		t.Fatalf("query indexes: %v", err)
	}
	defer rows.Close()

	got := map[string]bool{}
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			t.Fatalf("scan: %v", err)
		}
		got[name] = true
	}
	// An iteration error would otherwise surface as spurious "missing index"
	// failures (or none at all), hiding the real cause.
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate indexes: %v", err)
	}

	for _, idx := range wantIndexes {
		if !got[idx] {
			t.Errorf("missing index %q", idx)
		}
	}
}

// tableColumns returns the column names of table, in the order reported by
// PRAGMA table_info. It fails the test if the query, a row scan, or row
// iteration fails, or if no columns are reported. An empty result is taken to
// mean the table is missing.
//
// table is concatenated into the SQL text, so callers must pass trusted,
// fixed table names only.
func tableColumns(t *testing.T, ctx context.Context, s *store.Store, table string) []string {
	t.Helper()
	rows, err := s.DB().QueryContext(ctx, "PRAGMA table_info("+table+")")
	if err != nil {
		t.Fatalf("PRAGMA table_info(%s): %v", table, err)
	}
	defer rows.Close()

	var cols []string
	for rows.Next() {
		// PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
		var cid int
		var name, ctype string
		var notnull, pk int
		var dflt any
		if err := rows.Scan(&cid, &name, &ctype, &notnull, &dflt, &pk); err != nil {
			t.Fatalf("scan column: %v", err)
		}
		cols = append(cols, name)
	}
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate PRAGMA table_info(%s): %v", table, err)
	}
	if len(cols) == 0 {
		t.Fatalf("table %q has no columns (does it exist?)", table)
	}
	return cols
}

// mustExec runs query with args against the store's database and fails the
// test immediately if it returns an error.
func mustExec(t *testing.T, ctx context.Context, s *store.Store, query string, args ...any) {
	t.Helper()
	if _, err := s.DB().ExecContext(ctx, query, args...); err != nil {
		t.Fatalf("exec %q: %v", query, err)
	}
}

// equalStrings reports whether a and b have the same length and the same
// elements in the same order.
func equalStrings(a, b []string) bool {
	return slices.Equal(a, b)
}