Files
maxwarden/cmd/metagen/migrations.go
2025-03-06 23:54:11 -05:00

243 lines
4.9 KiB
Go

package main
import (
"errors"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
// create
func maybeCreateSqliteDb() {
if _, err := os.Stat("./passwords.db"); err != nil {
fmt.Printf("Creating new sqlite database")
err := os.WriteFile("./passwords.db", nil, 0755)
if err != nil {
fmt.Printf("Error creating Sqlite database.")
os.Exit(1)
}
m, err := migrate.New(
"file://./migrations",
"sqlite3://passwords.db",
)
if err != nil {
printStatus(false)
fmt.Println(err.Error())
os.Exit(1)
}
mErr := m.Up()
if mErr != nil {
printStatus(false)
fmt.Println(mErr.Error())
os.Exit(1)
}
printStatus(true)
}
}
// handle the running and creation of migrations
func migrations(args []string) {
if len(args) < 2 {
fmt.Println("Usage: metagen migrate [up, down, goto {V}, create {migration name}]")
os.Exit(1)
}
maybeCreateSqliteDb()
m, err := migrate.New(
"file://./migrations",
"sqlite3://passwords.db",
)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
migrateNum := 0
if len(args) >= 3 && args[1] != "create" {
var parseErr error
migrateNum, parseErr = strconv.Atoi(args[2])
if parseErr != nil {
fmt.Println("Please provide a valid migration number.")
os.Exit(1)
}
}
m.PrefetchMigrations = migrate.DefaultPrefetchMigrations
switch args[1] {
case "up":
err := m.Up()
if err != nil {
fmt.Println(err.Error())
}
case "down":
err := m.Down()
if err != nil {
fmt.Println(err.Error())
}
case "goto":
err := m.Migrate(uint(migrateNum))
if err != nil {
fmt.Println(err.Error())
}
case "create":
if len(args) < 3 {
fmt.Println("Please provide a name for the new migration.")
os.Exit(1)
}
createCmd("./migrations", time.Now(), defaultTimeFormat, args[2], "sql", true, 7, true)
}
}
// Defaults for naming new migration files.
const (
	// defaultTimeFormat is a time.Format layout producing YYYYMMDDhhmmss
	// versions; createCmd only permits sequence numbering with this format.
	defaultTimeFormat = "20060102150405"
	// NOTE(review): not referenced in the code visible here; presumably the
	// timezone intended for time-based versions — confirm before relying on it.
	defaultTimezone = "UTC"
)

// Sentinel errors returned by the migration-file helpers. Messages are
// lowercase without trailing punctuation, per Go convention, so they compose
// cleanly when wrapped.
var (
	errInvalidSequenceWidth     = errors.New("digits must be positive")
	errIncompatibleSeqAndFormat = errors.New("the seq and format options are mutually exclusive")
	errInvalidTimeFormat        = errors.New("time format may not be empty")
)
// createFile creates a new, empty file at filename and closes it straight away.
// O_EXCL makes the call fail if anything already exists at that path. The
// permission bits 0666 (before umask) are the same ones os.Create uses.
func createFile(filename string) error {
	const flags = os.O_RDWR | os.O_CREATE | os.O_EXCL
	file, openErr := os.OpenFile(filename, flags, 0666)
	if openErr == nil {
		return file.Close()
	}
	return openErr
}
// nextSeqVersion returns the sequence number that follows the highest version
// among the existing migration files in matches, zero-padded to seqDigits
// digits. Each entry's base name must start with a decimal version followed by
// "_" (e.g. "0000012_add_users.up.sql"). With no matches it returns version 1.
//
// It returns errInvalidSequenceWidth if seqDigits <= 0. It returns an error if
// any filename is malformed, if the next number would overflow uint64, or if it
// needs more than seqDigits digits.
func nextSeqVersion(matches []string, seqDigits int) (string, error) {
	if seqDigits <= 0 {
		return "", errInvalidSequenceWidth
	}

	// Scan every file instead of trusting the last one: glob results are
	// sorted lexically, which matches numeric order only when all versions
	// share the same width ("9_a" sorts after "10_b").
	var highest uint64
	for _, filename := range matches {
		base := filepath.Base(filename)
		idx := strings.Index(base, "_")
		if idx < 1 { // at least one digit must precede the "_"
			return "", fmt.Errorf("malformed migration filename: %s", filename)
		}
		seq, err := strconv.ParseUint(base[:idx], 10, 64)
		if err != nil {
			return "", fmt.Errorf("parsing version of migration %s: %w", filename, err)
		}
		if seq > highest {
			highest = seq
		}
	}

	next := highest + 1
	if next == 0 { // wrapped around: highest was the largest uint64
		return "", fmt.Errorf("next sequence number after %d overflows uint64", highest)
	}

	version := fmt.Sprintf("%0[2]*[1]d", next, seqDigits)
	if len(version) > seqDigits {
		return "", fmt.Errorf("next sequence number %s too large: at most %d digits are allowed", version, seqDigits)
	}
	return version, nil
}
// timeVersion turns startTime into a migration version string.
//
// The format argument selects the representation:
//   - "unix" gives whole seconds since the Unix epoch;
//   - "unixNano" gives nanoseconds since the Unix epoch;
//   - any other non-empty value is used as a time.Format layout.
//
// An empty format yields an empty version and errInvalidTimeFormat.
func timeVersion(startTime time.Time, format string) (version string, err error) {
	if format == "" {
		return "", errInvalidTimeFormat
	}
	if format == "unix" {
		return strconv.FormatInt(startTime.Unix(), 10), nil
	}
	if format == "unixNano" {
		return strconv.FormatInt(startTime.UnixNano(), 10), nil
	}
	return startTime.Format(format), nil
}
// createCmd creates an empty pair of migration files in dir, creating dir if
// needed:
//
//	<version>_<name>.up<ext>
//	<version>_<name>.down<ext>
//
// When seq is true, the version is the next zero-padded, seqDigits-wide
// sequence number after the existing *<ext> files in dir. Sequence mode may
// only be combined with format == defaultTimeFormat. Otherwise the version is
// derived from startTime using format (see timeVersion). ext may be given with
// or without a leading dot. When print is true, the path of each created file
// is logged once both files exist.
//
// It returns an error if the arguments are inconsistent, if name is empty, if
// the version cannot be computed, if a migration with that version already
// exists, or if either file cannot be created. In the last case no file of the
// pair is left behind.
func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
	if seq && format != defaultTimeFormat {
		return errIncompatibleSeqAndFormat
	}
	if name == "" {
		return errors.New("migration name may not be empty")
	}

	dir = filepath.Clean(dir)
	ext = "." + strings.TrimPrefix(ext, ".")

	// err is declared once and only ever assigned with "=", so no branch can
	// accidentally shadow it.
	var (
		version string
		err     error
	)
	if seq {
		var existing []string
		if existing, err = filepath.Glob(filepath.Join(dir, "*"+ext)); err != nil {
			return err
		}
		version, err = nextSeqVersion(existing, seqDigits)
	} else {
		version, err = timeVersion(startTime, format)
	}
	if err != nil {
		return err
	}

	var dupes []string
	if dupes, err = filepath.Glob(filepath.Join(dir, version+"_*"+ext)); err != nil {
		return err
	}
	if len(dupes) > 0 {
		return fmt.Errorf("duplicate migration version: %s", version)
	}

	if err = os.MkdirAll(dir, os.ModePerm); err != nil {
		return err
	}

	created := make([]string, 0, 2)
	for _, direction := range []string{"up", "down"} {
		filename := filepath.Join(dir, fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext))
		if err = createFile(filename); err != nil {
			// Don't leave an up migration behind without its down counterpart.
			for _, f := range created {
				_ = os.Remove(f) // best effort; the creation error is what gets reported
			}
			return err
		}
		created = append(created, filename)
	}

	if print {
		for _, filename := range created {
			// Fall back to the path as given if it can't be made absolute.
			if abs, absErr := filepath.Abs(filename); absErr == nil {
				filename = abs
			}
			log.Println(filename)
		}
	}
	return nil
}