diff --git a/migrate/migrate.go b/migrate/migrate.go index dc2f32a..ef581f1 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -22,6 +22,26 @@ import ( "github.com/scylladb/gocqlx/qb" ) +// DefaultAwaitSchemaAgreement controls whether checking for cluster schema agreement +// is disabled or if it is checked before each file or statement is applied. +// The default is not checking before each file or statement but only once after every +// migration has been run. +var DefaultAwaitSchemaAgreement = AwaitSchemaAgreementDisabled + +type awaitSchemaAgreement int + +// Options for checking schema agreement. +const ( + AwaitSchemaAgreementDisabled awaitSchemaAgreement = iota + AwaitSchemaAgreementBeforeEachFile + AwaitSchemaAgreementBeforeEachStatement +) + +// ShouldAwait decides whether to await schema agreement for the configured DefaultAwaitSchemaAgreement option above. +func (as awaitSchemaAgreement) ShouldAwait(stage awaitSchemaAgreement) bool { + return as == stage +} + const ( infoSchema = `CREATE TABLE IF NOT EXISTS gocqlx_migrate ( name text, @@ -117,6 +137,10 @@ func Migrate(ctx context.Context, session *gocql.Session, dir string) error { } } + if err = session.AwaitSchemaAgreement(ctx); err != nil { + return fmt.Errorf("awaiting schema agreement failed: %s", err) + } + return nil } @@ -149,6 +173,12 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names) defer iq.Release() + if DefaultAwaitSchemaAgreement.ShouldAwait(AwaitSchemaAgreementBeforeEachFile) { + if err = session.AwaitSchemaAgreement(ctx); err != nil { + return fmt.Errorf("awaiting schema agreement failed: %s", err) + } + } + i := 0 r := bytes.NewBuffer(b) for { @@ -176,6 +206,12 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do } } + if DefaultAwaitSchemaAgreement.ShouldAwait(AwaitSchemaAgreementBeforeEachStatement) { + if err = session.AwaitSchemaAgreement(ctx); err != nil { + return fmt.Errorf("awaiting schema agreement failed: %s", err) + } + } + // execute q := gocqlx.Query(session.Query(stmt).RetryPolicy(nil).WithContext(ctx), nil) if err := q.ExecRelease(); err != nil {