diff --git a/benchmark_test.go b/benchmark_test.go index 8873720..cbddb8d 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -11,7 +11,6 @@ import ( "os" "testing" - "github.com/gocql/gocql" "github.com/scylladb/gocqlx" . "github.com/scylladb/gocqlx/gocqlxtest" "github.com/scylladb/gocqlx/qb" @@ -49,12 +48,12 @@ func BenchmarkBaseGocqlInsert(b *testing.B) { session := CreateSession(b) defer session.Close() - if err := ExecStmt(session, benchPersonSchema); err != nil { + if err := session.ExecStmt(benchPersonSchema); err != nil { b.Fatal(err) } stmt, _ := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql() - q := session.Query(stmt) + q := session.Session.Query(stmt) defer q.Release() b.ResetTimer() @@ -72,12 +71,12 @@ func BenchmarkGocqlxInsert(b *testing.B) { session := CreateSession(b) defer session.Close() - if err := ExecStmt(session, benchPersonSchema); err != nil { + if err := session.ExecStmt(benchPersonSchema); err != nil { b.Fatal(err) } stmt, names := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql() - q := gocqlx.Query(session.Query(stmt), names) + q := session.Query(stmt, names) defer q.Release() b.ResetTimer() @@ -102,7 +101,7 @@ func BenchmarkBaseGocqlGet(b *testing.B) { initTable(b, session, people) stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql() - q := session.Query(stmt) + q := session.Session.Query(stmt) defer q.Release() var p benchPerson @@ -124,8 +123,8 @@ func BenchmarkGocqlxGet(b *testing.B) { initTable(b, session, people) - stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql() - q := session.Query(stmt) + stmt, names := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql() + q := session.Query(stmt, names) defer q.Release() var p benchPerson @@ -133,7 +132,7 @@ func BenchmarkGocqlxGet(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { q.Bind(people[i%len(people)].ID) - if err := gocqlx.Query(q, nil).Get(&p); err != nil { + if err := q.Get(&p); err != nil { b.Fatal(err) } } @@ -153,7 +152,7 @@ func BenchmarkBaseGocqlSelect(b *testing.B) { initTable(b, session, people) stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql() - q := session.Query(stmt) + q := session.Session.Query(stmt) defer q.Release() b.ResetTimer() @@ -179,8 +178,8 @@ func BenchmarkGocqlxSelect(b *testing.B) { initTable(b, session, people) - stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql() - q := gocqlx.Query(session.Query(stmt), nil) + stmt, names := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql() + q := session.Query(stmt, names) defer q.Release() b.ResetTimer() @@ -211,13 +210,13 @@ func loadFixtures() []*benchPerson { return v } -func initTable(b *testing.B, session *gocql.Session, people []*benchPerson) { - if err := ExecStmt(session, benchPersonSchema); err != nil { +func initTable(b *testing.B, session gocqlx.Session, people []*benchPerson) { + if err := session.ExecStmt(benchPersonSchema); err != nil { b.Fatal(err) } stmt, names := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql() - q := gocqlx.Query(session.Query(stmt), names) + q := session.Query(stmt, names) for _, p := range people { if err := q.BindStruct(p).Exec(); err != nil { diff --git a/doc_test.go b/doc_test.go index 7af22d4..7dc1f92 100644 --- a/doc_test.go +++ b/doc_test.go @@ -1,9 +1,26 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + package gocqlx_test import ( + "github.com/gocql/gocql" "github.com/scylladb/gocqlx" + "github.com/scylladb/gocqlx/qb" ) +func ExampleSession() { + cluster := gocql.NewCluster("host") + session, err := gocqlx.WrapSession(cluster.CreateSession()) + if err != nil { + // handle error + } + + builder := qb.Select("foo") + session.Query(builder.ToCql()) +} + func ExampleUDT() { // Just add gocqlx.UDT to a type, no need to implement marshalling functions type FullName struct { diff --git a/example_test.go b/example_test.go index ddb88fc..224d754 100644 --- a/example_test.go +++ b/example_test.go @@ -27,7 +27,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( PRIMARY KEY(first_name, last_name) )` - if err := ExecStmt(session, personSchema); err != nil { + if err := session.ExecStmt(personSchema); err != nil { t.Fatal("create table:", err) } @@ -49,7 +49,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( // Insert, bind data from struct. { stmt, names := qb.Insert("gocqlx_test.person").Columns("first_name", "last_name", "email").ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(stmt, names).BindStruct(p) if err := q.ExecRelease(); err != nil { t.Fatal(err) @@ -63,9 +63,9 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( TTL(86400 * time.Second). Timestamp(time.Now()). ToCql() + q := session.Query(stmt, names).BindStruct(p) - err := gocqlx.Query(session.Query(stmt), names).BindStruct(p).ExecRelease() - if err != nil { + if err := q.ExecRelease(); err != nil { t.Fatal(err) } } @@ -78,7 +78,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( Set("email"). Where(qb.Eq("first_name"), qb.Eq("last_name")). ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(stmt, names).BindStruct(p) if err := q.ExecRelease(); err != nil { t.Fatal(err) @@ -91,7 +91,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( AddNamed("email", "new_email"). Where(qb.Eq("first_name"), qb.Eq("last_name")). ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindStructMap(p, qb.M{ + q := session.Query(stmt, names).BindStructMap(p, qb.M{ "new_email": []string{"patricia2.citzen@gocqlx_test.com", "patricia3.citzen@gocqlx_test.com"}, }) @@ -124,7 +124,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( []string{"ian.citzen@gocqlx_test.com"}, }, } - q := gocqlx.Query(session.Query(stmt), names).BindStruct(&batch) + q := session.Query(stmt, names).BindStruct(&batch) if err := q.ExecRelease(); err != nil { t.Fatal(err) @@ -136,7 +136,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( var p Person stmt, names := qb.Select("gocqlx_test.person").Where(qb.Eq("first_name")).ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{ + q := session.Query(stmt, names).BindMap(qb.M{ "first_name": "Patricia", }) @@ -153,7 +153,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( var people []Person stmt, names := qb.Select("gocqlx_test.person").Where(qb.In("first_name")).ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{ + q := session.Query(stmt, names).BindMap(qb.M{ "first_name": []string{"Patricia", "Igy", "Ian"}, }) @@ -178,7 +178,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( Where(qb.Token("first_name").Gt()). Limit(10). ToCql() - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(stmt, names).BindStruct(p) var people []Person if err := q.SelectRelease(&people); err != nil { @@ -202,7 +202,8 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( "Citizen", []string{"jane.citzen@gocqlx_test.com"}, } - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(stmt, names).BindStruct(p) + if err := q.ExecRelease(); err != nil { t.Fatal(err) } diff --git a/gocqlxtest/gocqlxtest.go b/gocqlxtest/gocqlxtest.go index 552793e..048f519 100644 --- a/gocqlxtest/gocqlxtest.go +++ b/gocqlxtest/gocqlxtest.go @@ -13,6 +13,7 @@ import ( "time" "github.com/gocql/gocql" + "github.com/scylladb/gocqlx" ) var ( @@ -28,7 +29,7 @@ var ( var initOnce sync.Once // CreateSession creates a new gocql session from flags. -func CreateSession(tb testing.TB) *gocql.Session { +func CreateSession(tb testing.TB) gocqlx.Session { cluster := createCluster() return createSessionFromCluster(cluster, tb) } @@ -60,7 +61,7 @@ func createCluster() *gocql.ClusterConfig { return cluster } -func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocql.Session { +func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) gocqlx.Session { // Drop and re-create the keyspace once. Different tests should use their own // individual tables, but can assume that the table does not exist before. initOnce.Do(func() { @@ -68,11 +69,10 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq }) cluster.Keyspace = "gocqlx_test" - session, err := cluster.CreateSession() + session, err := gocqlx.WrapSession(cluster.CreateSession()) if err != nil { tb.Fatal("CreateSession:", err) } - return session } @@ -80,31 +80,23 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string c := *cluster c.Keyspace = "system" c.Timeout = 30 * time.Second - session, err := c.CreateSession() + session, err := gocqlx.WrapSession(c.CreateSession()) if err != nil { tb.Fatal(err) } defer session.Close() - err = ExecStmt(session, `DROP KEYSPACE IF EXISTS `+keyspace) + err = session.ExecStmt(`DROP KEYSPACE IF EXISTS ` + keyspace) if err != nil { tb.Fatalf("unable to drop keyspace: %v", err) } - err = ExecStmt(session, fmt.Sprintf(`CREATE KEYSPACE %s + err = session.ExecStmt(fmt.Sprintf(`CREATE KEYSPACE %s WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d }`, keyspace, *flagRF)) - if err != nil { tb.Fatalf("unable to create keyspace: %v", err) } } - -// ExecStmt executes a statement string. -func ExecStmt(s *gocql.Session, stmt string) error { - q := s.Query(stmt).RetryPolicy(nil) - defer q.Release() - return q.Exec() -} diff --git a/iterx.go b/iterx.go index 6633be8..52be4a1 100644 --- a/iterx.go +++ b/iterx.go @@ -31,15 +31,6 @@ type Iterx struct { values []interface{} } -// Iter creates a new Iterx from gocql.Query using a default mapper. -func Iter(q *gocql.Query) *Iterx { - return &Iterx{ - Iter: q.Iter(), - Mapper: DefaultMapper, - unsafe: DefaultUnsafe, - } -} - // Unsafe forces the iterator to ignore missing fields. By default when scanning // a struct if result row has a column that cannot be mapped to any destination // field an error is reported. With unsafe such columns are ignored. diff --git a/iterx_test.go b/iterx_test.go index b429093..d38d60c 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -49,11 +49,11 @@ func TestStruct(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TYPE gocqlx_test.FullName (first_Name text, last_name text)`); err != nil { + if err := session.ExecStmt(`CREATE TYPE gocqlx_test.FullName (first_Name text, last_name text)`); err != nil { t.Fatal("create type:", err) } - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_table ( + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_table ( testuuid timeuuid PRIMARY KEY, testtimestamp timestamp, testvarchar varchar, @@ -123,7 +123,9 @@ func TestStruct(t *testing.T) { Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}}, } - if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind( + const stmt = `INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` + + if err := session.Query(stmt, nil).Bind( m.Testuuid, m.Testtimestamp, m.Testvarchar, @@ -141,13 +143,13 @@ func TestStruct(t *testing.T) { m.Testinet, m.Testcustom, m.Testudt, - m.Testptrudt).Exec(); err != nil { + m.Testptrudt).ExecRelease(); err != nil { t.Fatal("insert:", err) } t.Run("get", func(t *testing.T) { var v StructTable - if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Get(&v); err != nil { + if err := session.Query(`SELECT * FROM struct_table`, nil).Get(&v); err != nil { t.Fatal("get failed", err) } @@ -158,7 +160,7 @@ func TestStruct(t *testing.T) { t.Run("select", func(t *testing.T) { var v []StructTable - if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Select(&v); err != nil { + if err := session.Query(`SELECT * FROM struct_table`, nil).Select(&v); err != nil { t.Fatal("select failed", err) } @@ -173,7 +175,7 @@ func TestStruct(t *testing.T) { t.Run("select ptr", func(t *testing.T) { var v []*StructTable - if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Select(&v); err != nil { + if err := session.Query(`SELECT * FROM struct_table`, nil).Select(&v); err != nil { t.Fatal("select failed", err) } @@ -192,7 +194,7 @@ func TestStruct(t *testing.T) { n int ) - i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter() + i := session.Query(`SELECT * FROM struct_table`, nil).Iter() for i.StructScan(&v) { n++ } @@ -212,18 +214,18 @@ func TestStruct(t *testing.T) { func TestScannable(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil { t.Fatal("create table:", err) } m := FullName{"John", "Doe"} - if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil { + if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, nil).Bind(m).Exec(); err != nil { t.Fatal("insert:", err) } t.Run("get", func(t *testing.T) { var v FullName - if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Get(&v); err != nil { + if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Get(&v); err != nil { t.Fatal("get failed", err) } @@ -234,7 +236,7 @@ func TestScannable(t *testing.T) { t.Run("select", func(t *testing.T) { var v []FullName - if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Select(&v); err != nil { + if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Select(&v); err != nil { t.Fatal("select failed", err) } @@ -249,7 +251,7 @@ func TestScannable(t *testing.T) { t.Run("select ptr", func(t *testing.T) { var v []*FullName - if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Select(&v); err != nil { + if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Select(&v); err != nil { t.Fatal("select failed", err) } @@ -266,19 +268,19 @@ func TestScannable(t *testing.T) { func TestStructOnly(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_only_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_only_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil { t.Fatal("create table:", err) } m := FullName{"John", "Doe"} - if err := session.Query(`INSERT INTO struct_only_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil { + if err := session.Query(`INSERT INTO struct_only_table (first_name, last_name) values (?, ?)`, nil).Bind(m.FirstName, m.LastName).Exec(); err != nil { t.Fatal("insert:", err) } t.Run("get", func(t *testing.T) { var v FullName - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Get(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Get(&v); err != nil { t.Fatal("get failed", err) } @@ -289,7 +291,7 @@ func TestStructOnly(t *testing.T) { t.Run("select", func(t *testing.T) { var v []FullName - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Select(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Select(&v); err != nil { t.Fatal("select failed", err) } @@ -304,7 +306,7 @@ func TestStructOnly(t *testing.T) { t.Run("select ptr", func(t *testing.T) { var v []*FullName - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Select(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Select(&v); err != nil { t.Fatal("select failed", err) } @@ -319,7 +321,7 @@ func TestStructOnly(t *testing.T) { t.Run("get error", func(t *testing.T) { var v FullName - err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).Get(&v) + err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Get(&v) if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") { t.Fatal("get expected validation error got", err) } @@ -327,7 +329,7 @@ func TestStructOnly(t *testing.T) { t.Run("select error", func(t *testing.T) { var v []FullName - err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).Select(&v) + err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Select(&v) if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") { t.Fatal("select expected validation error got", err) } @@ -337,7 +339,7 @@ func TestStructOnly(t *testing.T) { func TestStructOnlyUDT(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_only_udt_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_only_udt_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil { t.Fatal("create table:", err) } @@ -348,13 +350,13 @@ func TestStructOnlyUDT(t *testing.T) { }, } - if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil { + if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, nil).Bind(m.FirstName, m.LastName).Exec(); err != nil { t.Fatal("insert:", err) } t.Run("get", func(t *testing.T) { var v FullNameUDT - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Get(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Get(&v); err != nil { t.Fatal("get failed", err) } @@ -365,7 +367,7 @@ func TestStructOnlyUDT(t *testing.T) { t.Run("select", func(t *testing.T) { var v []FullNameUDT - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Select(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Select(&v); err != nil { t.Fatal("select failed", err) } @@ -380,7 +382,7 @@ func TestStructOnlyUDT(t *testing.T) { t.Run("select ptr", func(t *testing.T) { var v []*FullNameUDT - if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Select(&v); err != nil { + if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Select(&v); err != nil { t.Fatal("select failed", err) } @@ -395,7 +397,7 @@ func TestStructOnlyUDT(t *testing.T) { t.Run("get error", func(t *testing.T) { var v FullNameUDT - err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).Get(&v) + err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Get(&v) if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") { t.Fatal("get expected validation error got", err) } @@ -403,7 +405,7 @@ func TestStructOnlyUDT(t *testing.T) { t.Run("select error", func(t *testing.T) { var v []FullNameUDT - err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).Select(&v) + err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Select(&v) if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") { t.Fatal("select expected validation error got", err) } @@ -413,10 +415,10 @@ func TestStructOnlyUDT(t *testing.T) { func TestUnsafe(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil { t.Fatal("create table:", err) } - if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, "test", "test").Exec(); err != nil { + if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil { t.Fatal("insert:", err) } @@ -426,16 +428,16 @@ func TestUnsafe(t *testing.T) { t.Run("safe get", func(t *testing.T) { var v UnsafeTable - i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { + err := session.Query(`SELECT * FROM unsafe_table`, nil).Get(&v) + if err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { t.Fatal("expected ErrNotFound", "got", err) } }) t.Run("safe select", func(t *testing.T) { var v []UnsafeTable - i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { + err := session.Query(`SELECT * FROM unsafe_table`, nil).Select(&v) + if err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { t.Fatal("expected ErrNotFound", "got", err) } if cap(v) > 0 { @@ -445,8 +447,8 @@ func TestUnsafe(t *testing.T) { t.Run("unsafe get", func(t *testing.T) { var v UnsafeTable - i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Unsafe().Get(&v); err != nil { + err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Unsafe().Get(&v) + if err != nil { t.Fatal(err) } if v.Testtext != "test" { @@ -456,8 +458,8 @@ func TestUnsafe(t *testing.T) { t.Run("unsafe select", func(t *testing.T) { var v []UnsafeTable - i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Unsafe().Select(&v); err != nil { + err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Unsafe().Select(&v) + if err != nil { t.Fatal(err) } if len(v) != 1 { @@ -470,10 +472,12 @@ func TestUnsafe(t *testing.T) { t.Run("DefaultUnsafe select", func(t *testing.T) { gocqlx.DefaultUnsafe = true - defer func() { gocqlx.DefaultUnsafe = false }() + defer func() { + gocqlx.DefaultUnsafe = false + }() var v []UnsafeTable - i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Select(&v); err != nil { + err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Select(&v) + if err != nil { t.Fatal(err) } if len(v) != 1 { @@ -488,7 +492,7 @@ func TestUnsafe(t *testing.T) { func TestNotFound(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.not_found_table (testtext text PRIMARY KEY)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.not_found_table (testtext text PRIMARY KEY)`); err != nil { t.Fatal("create table:", err) } @@ -498,9 +502,7 @@ func TestNotFound(t *testing.T) { t.Run("get cql error", func(t *testing.T) { var v NotFoundTable - i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table WRONG`).RetryPolicy(nil)) - - err := i.Get(&v) + err := session.Query(`SELECT * FROM not_found_table WRONG`, nil).RetryPolicy(nil).Get(&v) if err == nil || !strings.Contains(err.Error(), "WRONG") { t.Fatal(err) } @@ -508,17 +510,15 @@ func TestNotFound(t *testing.T) { t.Run("get", func(t *testing.T) { var v NotFoundTable - i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`)) - if err := i.Get(&v); err != gocql.ErrNotFound { + err := session.Query(`SELECT * FROM not_found_table`, nil).Get(&v) + if err != gocql.ErrNotFound { t.Fatal("expected ErrNotFound", "got", err) } }) t.Run("select cql error", func(t *testing.T) { var v []NotFoundTable - i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table WRONG`).RetryPolicy(nil)) - - err := i.Select(&v) + err := session.Query(`SELECT * FROM not_found_table WRONG`, nil).RetryPolicy(nil).Select(&v) if err == nil || !strings.Contains(err.Error(), "WRONG") { t.Fatal(err) } @@ -526,8 +526,8 @@ func TestNotFound(t *testing.T) { t.Run("select", func(t *testing.T) { var v []NotFoundTable - i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`)) - if err := i.Select(&v); err != nil { + err := session.Query(`SELECT * FROM not_found_table`, nil).Select(&v) + if err != nil { t.Fatal(err) } if cap(v) > 0 { @@ -539,7 +539,7 @@ func TestNotFound(t *testing.T) { func TestErrorOnNil(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil { t.Fatal("create table:", err) } @@ -549,19 +549,19 @@ func TestErrorOnNil(t *testing.T) { ) t.Run("get", func(t *testing.T) { - err := gocqlx.Iter(session.Query(stmt)).Get(nil) + err := session.Query(stmt, nil).Get(nil) if err == nil || err.Error() != golden { t.Fatalf("Get()=%q expected %q error", err, golden) } }) t.Run("select", func(t *testing.T) { - err := gocqlx.Iter(session.Query(stmt)).Select(nil) + err := session.Query(stmt, nil).Select(nil) if err == nil || err.Error() != golden { t.Fatalf("Select()=%q expected %q error", err, golden) } }) t.Run("struct scan", func(t *testing.T) { - i := gocqlx.Iter(session.Query(stmt)) + i := session.Query(stmt, nil).Iter() i.StructScan(nil) err := i.Close() if err == nil || err.Error() != golden { @@ -573,14 +573,14 @@ func TestErrorOnNil(t *testing.T) { func TestPaging(t *testing.T) { session := CreateSession(t) defer session.Close() - if err := ExecStmt(session, `CREATE TABLE gocqlx_test.paging_table (id int PRIMARY KEY, val int)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.paging_table (id int PRIMARY KEY, val int)`); err != nil { t.Fatal("create table:", err) } - if err := ExecStmt(session, `CREATE INDEX id_val_index ON gocqlx_test.paging_table (val)`); err != nil { + if err := session.ExecStmt(`CREATE INDEX id_val_index ON gocqlx_test.paging_table (val)`); err != nil { t.Fatal("create index:", err) } - stmt, names := qb.Insert("gocqlx_test.paging_table").Columns("id", "val").ToCql() - q := gocqlx.Query(session.Query(stmt), names) + + q := session.Query(qb.Insert("gocqlx_test.paging_table").Columns("id", "val").ToCql()) for i := 0; i < 5000; i++ { if err := q.Bind(i, i).Exec(); err != nil { t.Fatal(err) @@ -597,12 +597,13 @@ func TestPaging(t *testing.T) { Where(qb.Lt("val")). AllowFiltering(). Columns("id", "val").ToCql() - it := gocqlx.Query(session.Query(stmt, 100).PageSize(10), names).Iter() - defer it.Close() + iter := session.Query(stmt, names).Bind(100).PageSize(10).Iter() + defer iter.Close() + var cnt int for { p := &Paging{} - if !it.StructScan(p) { + if !iter.StructScan(p) { break } cnt++ diff --git a/mapper.go b/mapper.go index 5b134eb..ba24e12 100644 --- a/mapper.go +++ b/mapper.go @@ -12,4 +12,6 @@ import ( // snake case. It can be set to whatever you want, but it is encouraged to be // set before gocqlx is used as name-to-field mappings are cached after first // use on a type. +// +// A custom mapper can always be set per Sessionm, Query and Iter. var DefaultMapper = reflectx.NewMapperFunc("db", reflectx.CamelToSnakeASCII) diff --git a/migrate/callback.go b/migrate/callback.go index 9b8a911..44aa505 100644 --- a/migrate/callback.go +++ b/migrate/callback.go @@ -7,7 +7,7 @@ package migrate import ( "context" - "github.com/gocql/gocql" + "github.com/scylladb/gocqlx" ) // CallbackEvent specifies type of the event when calling CallbackFunc. @@ -21,7 +21,7 @@ const ( // CallbackFunc enables interrupting the migration process and executing code // while migrating. If error is returned the migration is aborted. -type CallbackFunc func(ctx context.Context, session *gocql.Session, ev CallbackEvent, name string) error +type CallbackFunc func(ctx context.Context, session gocqlx.Session, ev CallbackEvent, name string) error // Callback is called before and after each migration. // See CallbackFunc for details. diff --git a/migrate/migrate.go b/migrate/migrate.go index 3e566b8..cf0c913 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -64,12 +64,12 @@ type Info struct { } // List provides a listing of applied migrations. -func List(ctx context.Context, session *gocql.Session) ([]*Info, error) { +func List(ctx context.Context, session gocqlx.Session) ([]*Info, error) { if err := ensureInfoTable(ctx, session); err != nil { return nil, err } - q := gocqlx.Query(session.Query(selectInfo).WithContext(ctx), nil) + q := session.ContextQuery(ctx, selectInfo, nil) var v []*Info if err := q.SelectRelease(&v); err == gocql.ErrNotFound { @@ -85,12 +85,12 @@ func List(ctx context.Context, session *gocql.Session) ([]*Info, error) { return v, nil } -func ensureInfoTable(ctx context.Context, session *gocql.Session) error { - return gocqlx.Query(session.Query(infoSchema).WithContext(ctx), nil).ExecRelease() +func ensureInfoTable(ctx context.Context, session gocqlx.Session) error { + return session.ContextQuery(ctx, infoSchema, nil).ExecRelease() } // Migrate reads the cql files from a directory and applies required migrations. -func Migrate(ctx context.Context, session *gocql.Session, dir string) error { +func Migrate(ctx context.Context, session gocqlx.Session, dir string) error { // get database migrations dbm, err := List(ctx, session) if err != nil { @@ -147,7 +147,7 @@ func Migrate(ctx context.Context, session *gocql.Session, dir string) error { return nil } -func applyMigration(ctx context.Context, session *gocql.Session, path string, done int) error { +func applyMigration(ctx context.Context, session gocqlx.Session, path string, done int) error { f, err := os.Open(path) if err != nil { return err @@ -173,8 +173,8 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do "end_time", ).ToCql() - iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names) - defer iq.Release() + update := session.ContextQuery(ctx, stmt, names) + defer update.Release() if DefaultAwaitSchemaAgreement.ShouldAwait(AwaitSchemaAgreementBeforeEachFile) { if err = session.AwaitSchemaAgreement(ctx); err != nil { @@ -216,7 +216,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do } // execute - q := gocqlx.Query(session.Query(stmt).RetryPolicy(nil).WithContext(ctx), nil) + q := session.ContextQuery(ctx, stmt, nil).RetryPolicy(nil) if err := q.ExecRelease(); err != nil { return fmt.Errorf("statement %d failed: %s", i, err) } @@ -224,7 +224,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do // update info info.Done = i info.EndTime = time.Now() - if err := iq.BindStruct(info).Exec(); err != nil { + if err := update.BindStruct(info).Exec(); err != nil { return fmt.Errorf("migration statement %d failed: %s", i, err) } } diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 56347ac..8798131 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -15,7 +15,7 @@ import ( "strings" "testing" - "github.com/gocql/gocql" + "github.com/scylladb/gocqlx" . "github.com/scylladb/gocqlx/gocqlxtest" "github.com/scylladb/gocqlx/migrate" ) @@ -30,16 +30,16 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.migrate_table ( var insertMigrate = `INSERT INTO gocqlx_test.migrate_table (testint, testuuid) VALUES (%d, now())` -func recreateTables(tb testing.TB, session *gocql.Session) { +func recreateTables(tb testing.TB, session gocqlx.Session) { tb.Helper() - if err := ExecStmt(session, "DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil { + if err := session.ExecStmt("DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil { tb.Fatal(err) } - if err := ExecStmt(session, migrateSchema); err != nil { + if err := session.ExecStmt(migrateSchema); err != nil { tb.Fatal(err) } - if err := ExecStmt(session, "TRUNCATE gocqlx_test.migrate_table"); err != nil { + if err := session.ExecStmt("TRUNCATE gocqlx_test.migrate_table"); err != nil { tb.Fatal(err) } } @@ -105,7 +105,7 @@ func TestMigrationNoSemicolon(t *testing.T) { defer session.Close() recreateTables(t, session) - if err := ExecStmt(session, migrateSchema); err != nil { + if err := session.ExecStmt(migrateSchema); err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestMigrationCallback(t *testing.T) { beforeCalled int afterCalled int ) - migrate.Callback = func(ctx context.Context, session *gocql.Session, ev migrate.CallbackEvent, name string) error { + migrate.Callback = func(ctx context.Context, session gocqlx.Session, ev migrate.CallbackEvent, name string) error { switch ev { case migrate.BeforeMigration: beforeCalled += 1 @@ -166,7 +166,7 @@ func TestMigrationCallback(t *testing.T) { defer session.Close() recreateTables(t, session) - if err := ExecStmt(session, migrateSchema); err != nil { + if err := session.ExecStmt(migrateSchema); err != nil { t.Fatal(err) } @@ -215,14 +215,11 @@ func makeMigrationDir(tb testing.TB, n int) (dir string) { return dir } -func countMigrations(tb testing.TB, session *gocql.Session) int { +func countMigrations(tb testing.TB, session gocqlx.Session) int { tb.Helper() - q := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table") - defer q.Release() - var v int - if err := q.Scan(&v); err != nil { + if err := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table", nil).Get(&v); err != nil { tb.Fatal(err) } return v diff --git a/queryx.go b/queryx.go index 025a675..cc5d85a 100644 --- a/queryx.go +++ b/queryx.go @@ -90,6 +90,8 @@ type Queryx struct { } // Query creates a new Queryx from gocql.Query using a default mapper. +// +// Deprecated: Use Session API instead. func Query(q *gocql.Query, names []string) *Queryx { return &Queryx{ Query: q, @@ -266,7 +268,9 @@ func (q *Queryx) SelectRelease(dest interface{}) error { // big to be loaded with Select in order to do row by row iteration. // See Iterx StructScan function. func (q *Queryx) Iter() *Iterx { - i := Iter(q.Query) - i.Mapper = q.Mapper - return i + return &Iterx{ + Iter: q.Query.Iter(), + Mapper: q.Mapper, + unsafe: DefaultUnsafe, + } } diff --git a/session.go b/session.go new file mode 100644 index 0000000..6262a2b --- /dev/null +++ b/session.go @@ -0,0 +1,60 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +package gocqlx + +import ( + "context" + + "github.com/gocql/gocql" + "github.com/scylladb/go-reflectx" +) + +// Session wraps gocql.Session and provides a modified Query function that +// returns Queryx instance. +// The original Session instance can be accessed as Session. +// The default mapper uses `db` tag and automatically converts struct field +// names to snake case. If needed package reflectx provides constructors +// for other types of mappers. +type Session struct { + *gocql.Session + Mapper *reflectx.Mapper +} + +// WrapSession should be called on CreateSession() gocql function to convert +// the created session to gocqlx.Session. +func WrapSession(session *gocql.Session, err error) (Session, error) { + return Session{ + Session: session, + Mapper: DefaultMapper, + }, err +} + +// ContextQuery is a helper function that allows to pass context when creating +// a query, see the "Query" function . +func (s Session) ContextQuery(ctx context.Context, stmt string, names []string) *Queryx { + return &Queryx{ + Query: s.Session.Query(stmt).WithContext(ctx), + Names: names, + Mapper: s.Mapper, + } +} + +// Query creates a new Queryx using the session mapper. +// The stmt and names parameters are typically result of a query builder +// (package qb) ToCql() function or come from table model (package table). +// The names parameter is a list of query parameters' names and it's used for +// binding. +func (s Session) Query(stmt string, names []string) *Queryx { + return &Queryx{ + Query: s.Session.Query(stmt), + Names: names, + Mapper: s.Mapper, + } +} + +// ExecStmt creates query and executes the given statement. +func (s Session) ExecStmt(stmt string) error { + return s.Query(stmt, nil).ExecRelease() +} diff --git a/table/example_test.go b/table/example_test.go index 43bf69f..2a36189 100644 --- a/table/example_test.go +++ b/table/example_test.go @@ -9,7 +9,6 @@ package table_test import ( "testing" - "github.com/scylladb/gocqlx" . "github.com/scylladb/gocqlx/gocqlxtest" "github.com/scylladb/gocqlx/qb" "github.com/scylladb/gocqlx/table" @@ -26,7 +25,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( email list, PRIMARY KEY(first_name, last_name) )` - if err := ExecStmt(session, personSchema); err != nil { + if err := session.ExecStmt(personSchema); err != nil { t.Fatal("create table:", err) } @@ -58,8 +57,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( []string{"patricia.citzen@gocqlx_test.com"}, } - stmt, names := personTable.Insert() - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(personTable.Insert()).BindStruct(p) if err := q.ExecRelease(); err != nil { t.Fatal(err) } @@ -73,8 +71,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( nil, // no email } - stmt, names := personTable.Get() // you can filter columns too - q := gocqlx.Query(session.Query(stmt), names).BindStruct(p) + q := session.Query(personTable.Get()).BindStruct(p) if err := q.GetRelease(&p); err != nil { t.Fatal(err) } @@ -87,9 +84,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( { var people []Person - stmt, names := personTable.Select() // you can filter columns too - q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{"first_name": "Patricia"}) - + q := session.Query(personTable.Select()).BindMap(qb.M{"first_name": "Patricia"}) if err := q.SelectRelease(&people); err != nil { t.Fatal(err) }