package store import ( "context" "database/sql" "embed" "errors" "fmt" "io/fs" "sort" "strconv" "strings" ) //go:embed migrations/*.sql var migrationsFS embed.FS // migration is a parsed migration file: an integer version extracted from // the filename (e.g. "0001_initial.sql" → 1) plus the raw SQL body. type migration struct { version int name string sql string } // migrate brings the database up to the latest migration found in fsys // under dir. The schema_migrations table tracks applied versions so // re-running is a no-op. Each migration runs inside its own transaction: // partial progress is never committed. // // fsys and dir are parameters (rather than using the package-level embedded // FS directly) so tests can inject synthetic migration sets. func (s *Store) migrate(ctx context.Context, fsys fs.FS, dir string) error { if _, err := s.db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, name TEXT NOT NULL, applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) )`); err != nil { return fmt.Errorf("create schema_migrations: %w", err) } applied, err := loadAppliedVersions(ctx, s.db) if err != nil { return err } migrations, err := loadMigrations(fsys, dir) if err != nil { return err } for _, m := range migrations { if applied[m.version] { continue } if err := applyMigration(ctx, s.db, m); err != nil { return fmt.Errorf("apply migration %s: %w", m.name, err) } } return nil } func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) { rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`) if err != nil { return nil, fmt.Errorf("read schema_migrations: %w", err) } defer rows.Close() applied := make(map[int]bool) for rows.Next() { var v int if err := rows.Scan(&v); err != nil { return nil, fmt.Errorf("scan schema_migrations: %w", err) } applied[v] = true } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iter schema_migrations: %w", err) } return applied, nil } // loadMigrations reads every *.sql file from dir inside fsys, parses its // version prefix, and returns the migrations sorted ascending by version. // Taking an fs.FS makes the function trivially testable with fstest.MapFS. func loadMigrations(fsys fs.FS, dir string) ([]migration, error) { entries, err := fs.ReadDir(fsys, dir) if err != nil { return nil, fmt.Errorf("read migrations dir %q: %w", dir, err) } var out []migration for _, e := range entries { if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") { continue } v, err := parseMigrationVersion(e.Name()) if err != nil { return nil, err } body, err := fs.ReadFile(fsys, dir+"/"+e.Name()) if err != nil { return nil, fmt.Errorf("read migration %s: %w", e.Name(), err) } out = append(out, migration{version: v, name: e.Name(), sql: string(body)}) } sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version }) // Sanity: reject duplicate versions — a typo in filenames could otherwise // silently skip a migration. for i := 1; i < len(out); i++ { if out[i].version == out[i-1].version { return nil, fmt.Errorf("duplicate migration version %d in %s and %s", out[i].version, out[i-1].name, out[i].name) } } return out, nil } func parseMigrationVersion(name string) (int, error) { // Expect "NNNN_name.sql" — strip the suffix and take the leading numeric // prefix. Be strict so typos aren't silently accepted. base := strings.TrimSuffix(name, ".sql") underscore := strings.IndexByte(base, '_') if underscore <= 0 { return 0, fmt.Errorf("migration %q: expected NNNN_name.sql", name) } v, err := strconv.Atoi(base[:underscore]) if err != nil { return 0, fmt.Errorf("migration %q: non-numeric version: %w", name, err) } if v <= 0 { return 0, fmt.Errorf("migration %q: version must be > 0", name) } return v, nil } func applyMigration(ctx context.Context, db *sql.DB, m migration) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin: %w", err) } if _, err := tx.ExecContext(ctx, m.sql); err != nil { return errors.Join(fmt.Errorf("exec: %w", err), tx.Rollback()) } if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations (version, name) VALUES (?, ?)`, m.version, m.name); err != nil { return errors.Join(fmt.Errorf("record: %w", err), tx.Rollback()) } if err := tx.Commit(); err != nil { return fmt.Errorf("commit: %w", err) } return nil }