forgejo-mcp-broker/internal/store/migrate.go

158 lines
4.5 KiB
Go
Raw Normal View History

package store
import (
"context"
"database/sql"
"embed"
"errors"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
)
// migrationsFS holds the *.sql files from the migrations directory, embedded
// by the go:embed directive below.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// migration describes one parsed migration file.
type migration struct {
	// version is the positive integer taken from the filename prefix
	// (e.g. "0001_initial.sql" → 1).
	version int
	// name is the migration's filename; sql is the file's raw SQL body.
	name, sql string
}
// migrate applies every migration found under dir in fsys that is not yet
// recorded in the schema_migrations table, creating that table first if it
// does not exist. Migrations are applied in ascending version order, and
// versions already recorded are skipped, so a second run with nothing new
// to apply changes nothing.
//
// Each migration is applied through applyMigration, which uses one
// transaction per migration. The first failure aborts the run and is returned
// wrapped with the migration's name; migrations applied before it stay
// committed.
//
// fsys and dir are parameters (instead of reading the package-level embedded
// FS directly) so tests can supply synthetic migration sets.
func (s *Store) migrate(ctx context.Context, fsys fs.FS, dir string) error {
	const ensureTable = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
name TEXT NOT NULL,
applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
)`

	_, err := s.db.ExecContext(ctx, ensureTable)
	if err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	// Read the recorded versions before loading the files, so a database
	// error takes precedence over a filesystem error.
	done, err := loadAppliedVersions(ctx, s.db)
	if err != nil {
		return err
	}

	all, err := loadMigrations(fsys, dir)
	if err != nil {
		return err
	}

	for _, mig := range all {
		if !done[mig.version] {
			if applyErr := applyMigration(ctx, s.db, mig); applyErr != nil {
				return fmt.Errorf("apply migration %s: %w", mig.name, applyErr)
			}
		}
	}
	return nil
}
// loadAppliedVersions returns the set of versions recorded in the
// schema_migrations table, as a map whose keys are the recorded versions
// (each mapped to true). Query, scan, and iteration failures are returned
// wrapped with a short description of the failing step.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	const query = `SELECT version FROM schema_migrations`

	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("read schema_migrations: %w", err)
	}
	defer rows.Close()

	seen := map[int]bool{}
	for rows.Next() {
		var version int
		if scanErr := rows.Scan(&version); scanErr != nil {
			return nil, fmt.Errorf("scan schema_migrations: %w", scanErr)
		}
		seen[version] = true
	}

	// rows.Next returning false can mean either "no more rows" or an
	// iteration error; rows.Err distinguishes the two.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iter schema_migrations: %w", iterErr)
	}
	return seen, nil
}
// loadMigrations reads every *.sql file directly under dir inside fsys,
// parses the version prefix of each filename, and returns the migrations
// sorted ascending by version. Subdirectories and files without a ".sql"
// suffix are skipped. Taking an fs.FS keeps the function testable with
// fstest.MapFS.
//
// dir may be "." to load migrations from the root of fsys.
//
// An error is returned if dir cannot be listed, a filename does not parse as
// a migration name, a file cannot be read, or two files share a version.
func loadMigrations(fsys fs.FS, dir string) ([]migration, error) {
	entries, err := fs.ReadDir(fsys, dir)
	if err != nil {
		return nil, fmt.Errorf("read migrations dir %q: %w", dir, err)
	}
	var out []migration
	for _, e := range entries {
		if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
			continue
		}
		v, err := parseMigrationVersion(e.Name())
		if err != nil {
			return nil, err
		}
		// fs.FS names must satisfy fs.ValidPath, which rejects "." as a path
		// element. Joining the root directory "." as "./name" would make
		// every read fail, so root-level files are addressed by bare name.
		p := e.Name()
		if dir != "." {
			p = dir + "/" + p
		}
		body, err := fs.ReadFile(fsys, p)
		if err != nil {
			return nil, fmt.Errorf("read migration %s: %w", e.Name(), err)
		}
		out = append(out, migration{version: v, name: e.Name(), sql: string(body)})
	}
	sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version })
	// Sanity: reject duplicate versions — a typo in filenames could otherwise
	// silently skip a migration. After sorting, duplicates are adjacent.
	for i := 1; i < len(out); i++ {
		if out[i].version == out[i-1].version {
			return nil, fmt.Errorf("duplicate migration version %d in %s and %s",
				out[i].version, out[i-1].name, out[i].name)
		}
	}
	return out, nil
}
// parseMigrationVersion extracts the integer version from a migration
// filename of the form "NNNN_name.sql": the text before the first underscore
// must consist solely of ASCII digits and denote a value greater than zero.
//
// It returns an error when there is no underscore or nothing precedes it,
// when the prefix contains any non-digit character (including a leading
// sign), when the number does not fit in an int, or when it is zero.
func parseMigrationVersion(name string) (int, error) {
	// Be strict so typos aren't silently accepted.
	base := strings.TrimSuffix(name, ".sql")
	underscore := strings.IndexByte(base, '_')
	if underscore <= 0 {
		return 0, fmt.Errorf("migration %q: expected NNNN_name.sql", name)
	}
	prefix := base[:underscore]
	// strconv.Atoi accepts a leading '+' or '-', which would let a name such
	// as "+1_x.sql" through as version 1; require plain digits up front.
	for i := 0; i < len(prefix); i++ {
		if c := prefix[i]; c < '0' || c > '9' {
			return 0, fmt.Errorf("migration %q: non-numeric version: %w", name, strconv.ErrSyntax)
		}
	}
	// Atoi can still fail here when the digits overflow int.
	v, err := strconv.Atoi(prefix)
	if err != nil {
		return 0, fmt.Errorf("migration %q: non-numeric version: %w", name, err)
	}
	if v <= 0 {
		return 0, fmt.Errorf("migration %q: version must be > 0", name)
	}
	return v, nil
}
// applyMigration runs a single migration inside its own transaction: it
// executes the migration's SQL, records the version and name in
// schema_migrations, and commits.
//
// If executing the SQL or recording the row fails, the transaction is rolled
// back and the returned error joins the step's error with any error from the
// rollback. Failures to begin or commit are returned wrapped.
func applyMigration(ctx context.Context, db *sql.DB, m migration) error {
	const recordStmt = `INSERT INTO schema_migrations (version, name) VALUES (?, ?)`

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin: %w", err)
	}

	// abort rolls the transaction back and reports both the triggering
	// failure and the rollback's own outcome (errors.Join drops nil values).
	abort := func(cause error) error {
		return errors.Join(cause, tx.Rollback())
	}

	if _, err = tx.ExecContext(ctx, m.sql); err != nil {
		return abort(fmt.Errorf("exec: %w", err))
	}
	if _, err = tx.ExecContext(ctx, recordStmt, m.version, m.name); err != nil {
		return abort(fmt.Errorf("record: %w", err))
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}