158 lines
4.5 KiB
Go
158 lines
4.5 KiB
Go
|
|
package store
|
||
|
|
|
||
|
|
import (
	"context"
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"io/fs"
	"path"
	"sort"
	"strconv"
	"strings"
)
|
||
|
|
|
||
|
|
//go:embed migrations/*.sql
|
||
|
|
var migrationsFS embed.FS
|
||
|
|
|
||
|
|
// migration holds one migration file after parsing: its numeric version
// (taken from the filename prefix, so "0001_initial.sql" becomes 1), the
// filename itself, and the SQL text to execute.
type migration struct {
	version int    // filename prefix, e.g. 1 for "0001_initial.sql"
	name    string // filename; recorded in schema_migrations and used in errors
	sql     string // raw file contents
}
|
||
|
|
|
||
|
|
// migrate brings the database up to the latest migration found in fsys
|
||
|
|
// under dir. The schema_migrations table tracks applied versions so
|
||
|
|
// re-running is a no-op. Each migration runs inside its own transaction:
|
||
|
|
// partial progress is never committed.
|
||
|
|
//
|
||
|
|
// fsys and dir are parameters (rather than using the package-level embedded
|
||
|
|
// FS directly) so tests can inject synthetic migration sets.
|
||
|
|
func (s *Store) migrate(ctx context.Context, fsys fs.FS, dir string) error {
|
||
|
|
if _, err := s.db.ExecContext(ctx, `
|
||
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||
|
|
version INTEGER PRIMARY KEY,
|
||
|
|
name TEXT NOT NULL,
|
||
|
|
applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||
|
|
)`); err != nil {
|
||
|
|
return fmt.Errorf("create schema_migrations: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
applied, err := loadAppliedVersions(ctx, s.db)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
migrations, err := loadMigrations(fsys, dir)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, m := range migrations {
|
||
|
|
if applied[m.version] {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if err := applyMigration(ctx, s.db, m); err != nil {
|
||
|
|
return fmt.Errorf("apply migration %s: %w", m.name, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// loadAppliedVersions returns the set of migration versions already recorded
// in schema_migrations, as a map whose keys are the recorded versions (each
// mapped to true). On success the map is non-nil, even when the table is
// empty. Query, scan, and iteration failures are wrapped with the step that
// failed.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("read schema_migrations: %w", err)
	}
	defer rows.Close()

	seen := map[int]bool{}
	for rows.Next() {
		var version int
		if scanErr := rows.Scan(&version); scanErr != nil {
			return nil, fmt.Errorf("scan schema_migrations: %w", scanErr)
		}
		seen[version] = true
	}
	// rows.Next returns false on both exhaustion and failure; distinguish here.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iter schema_migrations: %w", iterErr)
	}
	return seen, nil
}
|
||
|
|
|
||
|
|
// loadMigrations reads every *.sql file from dir inside fsys, parses its
|
||
|
|
// version prefix, and returns the migrations sorted ascending by version.
|
||
|
|
// Taking an fs.FS makes the function trivially testable with fstest.MapFS.
|
||
|
|
func loadMigrations(fsys fs.FS, dir string) ([]migration, error) {
|
||
|
|
entries, err := fs.ReadDir(fsys, dir)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var out []migration
|
||
|
|
for _, e := range entries {
|
||
|
|
if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
v, err := parseMigrationVersion(e.Name())
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
body, err := fs.ReadFile(fsys, dir+"/"+e.Name())
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("read migration %s: %w", e.Name(), err)
|
||
|
|
}
|
||
|
|
out = append(out, migration{version: v, name: e.Name(), sql: string(body)})
|
||
|
|
}
|
||
|
|
|
||
|
|
sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version })
|
||
|
|
|
||
|
|
// Sanity: reject duplicate versions — a typo in filenames could otherwise
|
||
|
|
// silently skip a migration.
|
||
|
|
for i := 1; i < len(out); i++ {
|
||
|
|
if out[i].version == out[i-1].version {
|
||
|
|
return nil, fmt.Errorf("duplicate migration version %d in %s and %s",
|
||
|
|
out[i].version, out[i-1].name, out[i].name)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// parseMigrationVersion extracts the integer version from a migration
// filename of the form "NNNN_name.sql" (for example "0001_initial.sql" → 1).
//
// The text before the first underscore must be non-empty, consist solely of
// ASCII digits, and parse to an int greater than zero. Anything else is an
// error, so typos in filenames are not silently accepted.
func parseMigrationVersion(name string) (int, error) {
	base := strings.TrimSuffix(name, ".sql")
	prefix, _, found := strings.Cut(base, "_")
	if !found || prefix == "" {
		return 0, fmt.Errorf("migration %q: expected NNNN_name.sql", name)
	}
	// strconv.Atoi on its own also accepts a leading '+' or '-', so a name
	// like "+1_x.sql" would slip through. Require plain digits first.
	for i := 0; i < len(prefix); i++ {
		if c := prefix[i]; c < '0' || c > '9' {
			return 0, fmt.Errorf("migration %q: non-numeric version %q", name, prefix)
		}
	}
	// With an all-digit prefix, Atoi can only fail on overflow.
	v, err := strconv.Atoi(prefix)
	if err != nil {
		return 0, fmt.Errorf("migration %q: invalid version: %w", name, err)
	}
	if v <= 0 {
		return 0, fmt.Errorf("migration %q: version must be > 0", name)
	}
	return v, nil
}
|
||
|
|
|
||
|
|
func applyMigration(ctx context.Context, db *sql.DB, m migration) error {
|
||
|
|
tx, err := db.BeginTx(ctx, nil)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("begin: %w", err)
|
||
|
|
}
|
||
|
|
if _, err := tx.ExecContext(ctx, m.sql); err != nil {
|
||
|
|
return errors.Join(fmt.Errorf("exec: %w", err), tx.Rollback())
|
||
|
|
}
|
||
|
|
if _, err := tx.ExecContext(ctx,
|
||
|
|
`INSERT INTO schema_migrations (version, name) VALUES (?, ?)`,
|
||
|
|
m.version, m.name); err != nil {
|
||
|
|
return errors.Join(fmt.Errorf("record: %w", err), tx.Rollback())
|
||
|
|
}
|
||
|
|
if err := tx.Commit(); err != nil {
|
||
|
|
return fmt.Errorf("commit: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|