Merge pull request #3 from hailocab/upstream-merge
Upstream merge Add Session wrapper With this patch we can now use gocqlx like: ``` session.Query(`SELECT * FROM struct_table`, nil).Get(&v) ``` instead of (old format): ``` gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Get(&v) ``` Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
committed by
Michal Jan Matczuk
parent
ab279e68ed
commit
95d96fa939
@@ -7,7 +7,7 @@ package migrate
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/scylladb/gocqlx"
|
||||
)
|
||||
|
||||
// CallbackEvent specifies type of the event when calling CallbackFunc.
|
||||
@@ -21,7 +21,7 @@ const (
|
||||
|
||||
// CallbackFunc enables interrupting the migration process and executing code
|
||||
// while migrating. If error is returned the migration is aborted.
|
||||
type CallbackFunc func(ctx context.Context, session *gocql.Session, ev CallbackEvent, name string) error
|
||||
type CallbackFunc func(ctx context.Context, session gocqlx.Session, ev CallbackEvent, name string) error
|
||||
|
||||
// Callback is called before and after each migration.
|
||||
// See CallbackFunc for details.
|
||||
|
||||
@@ -64,12 +64,12 @@ type Info struct {
|
||||
}
|
||||
|
||||
// List provides a listing of applied migrations.
|
||||
func List(ctx context.Context, session *gocql.Session) ([]*Info, error) {
|
||||
func List(ctx context.Context, session gocqlx.Session) ([]*Info, error) {
|
||||
if err := ensureInfoTable(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := gocqlx.Query(session.Query(selectInfo).WithContext(ctx), nil)
|
||||
q := session.ContextQuery(ctx, selectInfo, nil)
|
||||
|
||||
var v []*Info
|
||||
if err := q.SelectRelease(&v); err == gocql.ErrNotFound {
|
||||
@@ -85,12 +85,12 @@ func List(ctx context.Context, session *gocql.Session) ([]*Info, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func ensureInfoTable(ctx context.Context, session *gocql.Session) error {
|
||||
return gocqlx.Query(session.Query(infoSchema).WithContext(ctx), nil).ExecRelease()
|
||||
func ensureInfoTable(ctx context.Context, session gocqlx.Session) error {
|
||||
return session.ContextQuery(ctx, infoSchema, nil).ExecRelease()
|
||||
}
|
||||
|
||||
// Migrate reads the cql files from a directory and applies required migrations.
|
||||
func Migrate(ctx context.Context, session *gocql.Session, dir string) error {
|
||||
func Migrate(ctx context.Context, session gocqlx.Session, dir string) error {
|
||||
// get database migrations
|
||||
dbm, err := List(ctx, session)
|
||||
if err != nil {
|
||||
@@ -147,7 +147,7 @@ func Migrate(ctx context.Context, session *gocql.Session, dir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyMigration(ctx context.Context, session *gocql.Session, path string, done int) error {
|
||||
func applyMigration(ctx context.Context, session gocqlx.Session, path string, done int) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -173,8 +173,8 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
|
||||
"end_time",
|
||||
).ToCql()
|
||||
|
||||
iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names)
|
||||
defer iq.Release()
|
||||
update := session.ContextQuery(ctx, stmt, names)
|
||||
defer update.Release()
|
||||
|
||||
if DefaultAwaitSchemaAgreement.ShouldAwait(AwaitSchemaAgreementBeforeEachFile) {
|
||||
if err = session.AwaitSchemaAgreement(ctx); err != nil {
|
||||
@@ -216,7 +216,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
|
||||
}
|
||||
|
||||
// execute
|
||||
q := gocqlx.Query(session.Query(stmt).RetryPolicy(nil).WithContext(ctx), nil)
|
||||
q := session.ContextQuery(ctx, stmt, nil).RetryPolicy(nil)
|
||||
if err := q.ExecRelease(); err != nil {
|
||||
return fmt.Errorf("statement %d failed: %s", i, err)
|
||||
}
|
||||
@@ -224,7 +224,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
|
||||
// update info
|
||||
info.Done = i
|
||||
info.EndTime = time.Now()
|
||||
if err := iq.BindStruct(info).Exec(); err != nil {
|
||||
if err := update.BindStruct(info).Exec(); err != nil {
|
||||
return fmt.Errorf("migration statement %d failed: %s", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/scylladb/gocqlx"
|
||||
. "github.com/scylladb/gocqlx/gocqlxtest"
|
||||
"github.com/scylladb/gocqlx/migrate"
|
||||
)
|
||||
@@ -30,16 +30,16 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.migrate_table (
|
||||
|
||||
var insertMigrate = `INSERT INTO gocqlx_test.migrate_table (testint, testuuid) VALUES (%d, now())`
|
||||
|
||||
func recreateTables(tb testing.TB, session *gocql.Session) {
|
||||
func recreateTables(tb testing.TB, session gocqlx.Session) {
|
||||
tb.Helper()
|
||||
|
||||
if err := ExecStmt(session, "DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil {
|
||||
if err := session.ExecStmt("DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ExecStmt(session, migrateSchema); err != nil {
|
||||
if err := session.ExecStmt(migrateSchema); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := ExecStmt(session, "TRUNCATE gocqlx_test.migrate_table"); err != nil {
|
||||
if err := session.ExecStmt("TRUNCATE gocqlx_test.migrate_table"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func TestMigrationNoSemicolon(t *testing.T) {
|
||||
defer session.Close()
|
||||
recreateTables(t, session)
|
||||
|
||||
if err := ExecStmt(session, migrateSchema); err != nil {
|
||||
if err := session.ExecStmt(migrateSchema); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestMigrationCallback(t *testing.T) {
|
||||
beforeCalled int
|
||||
afterCalled int
|
||||
)
|
||||
migrate.Callback = func(ctx context.Context, session *gocql.Session, ev migrate.CallbackEvent, name string) error {
|
||||
migrate.Callback = func(ctx context.Context, session gocqlx.Session, ev migrate.CallbackEvent, name string) error {
|
||||
switch ev {
|
||||
case migrate.BeforeMigration:
|
||||
beforeCalled += 1
|
||||
@@ -166,7 +166,7 @@ func TestMigrationCallback(t *testing.T) {
|
||||
defer session.Close()
|
||||
recreateTables(t, session)
|
||||
|
||||
if err := ExecStmt(session, migrateSchema); err != nil {
|
||||
if err := session.ExecStmt(migrateSchema); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -215,14 +215,11 @@ func makeMigrationDir(tb testing.TB, n int) (dir string) {
|
||||
return dir
|
||||
}
|
||||
|
||||
func countMigrations(tb testing.TB, session *gocql.Session) int {
|
||||
func countMigrations(tb testing.TB, session gocqlx.Session) int {
|
||||
tb.Helper()
|
||||
|
||||
q := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table")
|
||||
defer q.Release()
|
||||
|
||||
var v int
|
||||
if err := q.Scan(&v); err != nil {
|
||||
if err := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table", nil).Get(&v); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return v
|
||||
|
||||
Reference in New Issue
Block a user