From b4347d2757e1a388f315a65f91ed0e80d1e38d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Thu, 24 May 2018 12:54:11 +0200 Subject: [PATCH] migrate: migration integration tests --- migrate/migrate.go | 25 +++-- migrate/migrate_test.go | 237 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 11 deletions(-) create mode 100644 migrate/migrate_test.go diff --git a/migrate/migrate.go b/migrate/migrate.go index dc1fe45..dc2f32a 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -14,6 +14,7 @@ import ( "os" "path/filepath" "sort" + "strings" "time" "github.com/gocql/gocql" @@ -93,9 +94,6 @@ func Migrate(ctx context.Context, session *gocql.Session, dir string) error { fmt.Println(dbm[i].Name, filepath.Base(fm[i]), i) return errors.New("inconsistent migrations") } - } - - for i := 0; i < len(dbm); i++ { c, err := fileChecksum(fm[i]) if err != nil { return fmt.Errorf("failed to calculate checksum for %q: %s", fm[i], err) @@ -151,18 +149,17 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names) defer iq.Release() - if Callback != nil { - if err := Callback(ctx, session, BeforeMigration, info.Name); err != nil { - return fmt.Errorf("before migration callback failed: %s", err) - } - } - i := 0 r := bytes.NewBuffer(b) for { stmt, err := r.ReadString(';') if err == io.EOF { - break + if strings.TrimSpace(stmt) != "" { + // handle missing semicolon after last statement + err = nil + } else { + break + } } if err != nil { return err @@ -173,6 +170,12 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do continue } + if Callback != nil && i == 1 { + if err := Callback(ctx, session, BeforeMigration, info.Name); err != nil { + return fmt.Errorf("before migration callback failed: %s", err) + } + } + // execute q := gocqlx.Query(session.Query(stmt).RetryPolicy(nil).WithContext(ctx), nil) if err := q.ExecRelease(); err != nil { @@ -190,7 +193,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do return fmt.Errorf("no migration statements found in %q", info.Name) } - if Callback != nil { + if Callback != nil && i > done { if err := Callback(ctx, session, AfterMigration, info.Name); err != nil { return fmt.Errorf("after migration callback failed: %s", err) } diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go new file mode 100644 index 0000000..56347ac --- /dev/null +++ b/migrate/migrate_test.go @@ -0,0 +1,237 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +// +build all integration + +package migrate_test + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/gocql/gocql" + . "github.com/scylladb/gocqlx/gocqlxtest" + "github.com/scylladb/gocqlx/migrate" +) + +var migrateSchema = ` +CREATE TABLE IF NOT EXISTS gocqlx_test.migrate_table ( + testint int, + testuuid timeuuid, + PRIMARY KEY(testint, testuuid) +) +` + +var insertMigrate = `INSERT INTO gocqlx_test.migrate_table (testint, testuuid) VALUES (%d, now())` + +func recreateTables(tb testing.TB, session *gocql.Session) { + tb.Helper() + + if err := ExecStmt(session, "DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil { + tb.Fatal(err) + } + if err := ExecStmt(session, migrateSchema); err != nil { + tb.Fatal(err) + } + if err := ExecStmt(session, "TRUNCATE gocqlx_test.migrate_table"); err != nil { + tb.Fatal(err) + } +} + +func TestMigration(t *testing.T) { + session := CreateSession(t) + defer session.Close() + recreateTables(t, session) + + ctx := context.Background() + + t.Run("init", func(t *testing.T) { + dir := makeMigrationDir(t, 2) + defer os.Remove(dir) + + if err := migrate.Migrate(ctx, session, dir); err != nil { + t.Fatal(err) + } + if c := countMigrations(t, session); c != 2 { + t.Fatal("expected 2 migration got", c) + } + }) + + t.Run("update", func(t *testing.T) { + dir := makeMigrationDir(t, 4) + defer os.Remove(dir) + + if err := migrate.Migrate(ctx, session, dir); err != nil { + t.Fatal(err) + } + if c := countMigrations(t, session); c != 4 { + t.Fatal("expected 4 migration got", c) + } + }) + + t.Run("ahead", func(t *testing.T) { + dir := makeMigrationDir(t, 2) + defer os.Remove(dir) + + if err := migrate.Migrate(ctx, session, dir); err == nil || !strings.Contains(err.Error(), "ahead") { + t.Fatal("expected error") + } else { + t.Log(err) + } + }) + + t.Run("tempered with file", func(t *testing.T) { + dir := makeMigrationDir(t, 4) + defer os.Remove(dir) + + temperFile(t, dir, "3.cql") + + if err := migrate.Migrate(ctx, session, dir); err == nil || !strings.Contains(err.Error(), "tempered") { + t.Fatal("expected error") + } else { + t.Log(err) + } + }) +} + +func TestMigrationNoSemicolon(t *testing.T) { + session := CreateSession(t) + defer session.Close() + recreateTables(t, session) + + if err := ExecStmt(session, migrateSchema); err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + dir := makeMigrationDir(t, 1) + defer os.Remove(dir) + + f, err := os.OpenFile(filepath.Join(dir, "0.cql"), os.O_WRONLY|os.O_APPEND, 0) + if err != nil { + t.Fatal(err) + } + fmt.Fprintf(f, insertMigrate, 0) // note no ; at the end + f.Close() + + if err := migrate.Migrate(ctx, session, dir); err != nil { + t.Fatal(err) + } + if c := countMigrations(t, session); c != 2 { + t.Fatal("expected 2 migration got", c) + } +} + +func TestMigrationCallback(t *testing.T) { + var ( + beforeCalled int + afterCalled int + ) + migrate.Callback = func(ctx context.Context, session *gocql.Session, ev migrate.CallbackEvent, name string) error { + switch ev { + case migrate.BeforeMigration: + beforeCalled += 1 + case migrate.AfterMigration: + afterCalled += 1 + } + return nil + } + + defer func() { + migrate.Callback = nil + }() + + reset := func() { + beforeCalled = 0 + afterCalled = 0 + } + + assertCallbacks := func(t *testing.T, b, a int) { + if beforeCalled != b { + t.Fatalf("expected %d before calls got %d", b, beforeCalled) + } + if afterCalled != b { + t.Fatalf("expected %d after calls got %d", a, afterCalled) + } + } + + session := CreateSession(t) + defer session.Close() + recreateTables(t, session) + + if err := ExecStmt(session, migrateSchema); err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + t.Run("init", func(t *testing.T) { + dir := makeMigrationDir(t, 2) + defer os.Remove(dir) + reset() + + if err := migrate.Migrate(ctx, session, dir); err != nil { + t.Fatal(err) + } + assertCallbacks(t, 2, 2) + }) + + t.Run("no duplicate calls", func(t *testing.T) { + dir := makeMigrationDir(t, 4) + defer os.Remove(dir) + reset() + + if err := migrate.Migrate(ctx, session, dir); err != nil { + t.Fatal(err) + } + assertCallbacks(t, 2, 2) + }) +} + +func makeMigrationDir(tb testing.TB, n int) (dir string) { + tb.Helper() + + dir, err := ioutil.TempDir("", "gocqlx_migrate") + if err != nil { + tb.Fatal(err) + } + + for i := 0; i < n; i++ { + path := filepath.Join(dir, fmt.Sprint(i, ".cql")) + cql := []byte(fmt.Sprintf(insertMigrate, i) + ";") + if err := ioutil.WriteFile(path, cql, os.ModePerm); err != nil { + os.Remove(dir) + tb.Fatal(err) + } + } + + return dir +} + +func countMigrations(tb testing.TB, session *gocql.Session) int { + tb.Helper() + + q := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table") + defer q.Release() + + var v int + if err := q.Scan(&v); err != nil { + tb.Fatal(err) + } + return v +} + +func temperFile(tb testing.TB, dir, name string) { + tb.Helper() + + if err := ioutil.WriteFile(filepath.Join(dir, name), []byte("SELECT * FROM bla;"), os.ModePerm); err != nil { + tb.Fatal(err) + } +}