diff --git a/migrate/callback.go b/migrate/callback.go new file mode 100644 index 0000000..9b8a911 --- /dev/null +++ b/migrate/callback.go @@ -0,0 +1,28 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +package migrate + +import ( + "context" + + "github.com/gocql/gocql" +) + +// CallbackEvent specifies type of the event when calling CallbackFunc. +type CallbackEvent uint8 + +// enumeration of CallbackEvents +const ( + BeforeMigration CallbackEvent = iota + AfterMigration +) + +// CallbackFunc enables interrupting the migration process and executing code +// while migrating. If error is returned the migration is aborted. +type CallbackFunc func(ctx context.Context, session *gocql.Session, ev CallbackEvent, name string) error + +// Callback is called before and after each migration. +// See CallbackFunc for details. +var Callback CallbackFunc diff --git a/migrate/migrate.go b/migrate/migrate.go index dc355a8..dc1fe45 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -151,8 +151,13 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names) defer iq.Release() - i := 1 - stmtCount := 0 + if Callback != nil { + if err := Callback(ctx, session, BeforeMigration, info.Name); err != nil { + return fmt.Errorf("before migration callback failed: %s", err) + } + } + + i := 0 r := bytes.NewBuffer(b) for { stmt, err := r.ReadString(';') @@ -162,10 +167,9 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do if err != nil { return err } - stmtCount++ + i++ if i <= done { - i++ continue } @@ -181,12 +185,16 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do if err := iq.BindStruct(info).Exec(); err != nil { return fmt.Errorf("migration statement %d failed: %s", i, err) } - - i++ } - if stmtCount == 0 { + if i == 0 { return fmt.Errorf("no migration statements found in %q", info.Name) } + if Callback != nil { + if err := Callback(ctx, session, AfterMigration, info.Name); err != nil { + return fmt.Errorf("after migration callback failed: %s", err) + } + } + return nil }