Merge pull request #3 from hailocab/upstream-merge

Upstream merge
Add Session wrapper

With this patch we can now use gocqlx like:

```
session.Query(`SELECT * FROM struct_table`, nil).Get(&v)
```

instead of (old format):

```
gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Get(&v)
```

Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
Michał Matczuk
2020-04-20 17:39:17 +02:00
committed by Michal Jan Matczuk
parent ab279e68ed
commit 95d96fa939
13 changed files with 208 additions and 149 deletions

View File

@@ -11,7 +11,6 @@ import (
"os"
"testing"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx"
. "github.com/scylladb/gocqlx/gocqlxtest"
"github.com/scylladb/gocqlx/qb"
@@ -49,12 +48,12 @@ func BenchmarkBaseGocqlInsert(b *testing.B) {
session := CreateSession(b)
defer session.Close()
if err := ExecStmt(session, benchPersonSchema); err != nil {
if err := session.ExecStmt(benchPersonSchema); err != nil {
b.Fatal(err)
}
stmt, _ := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql()
q := session.Query(stmt)
q := session.Session.Query(stmt)
defer q.Release()
b.ResetTimer()
@@ -72,12 +71,12 @@ func BenchmarkGocqlxInsert(b *testing.B) {
session := CreateSession(b)
defer session.Close()
if err := ExecStmt(session, benchPersonSchema); err != nil {
if err := session.ExecStmt(benchPersonSchema); err != nil {
b.Fatal(err)
}
stmt, names := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql()
q := gocqlx.Query(session.Query(stmt), names)
q := session.Query(stmt, names)
defer q.Release()
b.ResetTimer()
@@ -102,7 +101,7 @@ func BenchmarkBaseGocqlGet(b *testing.B) {
initTable(b, session, people)
stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql()
q := session.Query(stmt)
q := session.Session.Query(stmt)
defer q.Release()
var p benchPerson
@@ -124,8 +123,8 @@ func BenchmarkGocqlxGet(b *testing.B) {
initTable(b, session, people)
stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql()
q := session.Query(stmt)
stmt, names := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Where(qb.Eq("id")).Limit(1).ToCql()
q := session.Query(stmt, names)
defer q.Release()
var p benchPerson
@@ -133,7 +132,7 @@ func BenchmarkGocqlxGet(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
q.Bind(people[i%len(people)].ID)
if err := gocqlx.Query(q, nil).Get(&p); err != nil {
if err := q.Get(&p); err != nil {
b.Fatal(err)
}
}
@@ -153,7 +152,7 @@ func BenchmarkBaseGocqlSelect(b *testing.B) {
initTable(b, session, people)
stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql()
q := session.Query(stmt)
q := session.Session.Query(stmt)
defer q.Release()
b.ResetTimer()
@@ -179,8 +178,8 @@ func BenchmarkGocqlxSelect(b *testing.B) {
initTable(b, session, people)
stmt, _ := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql()
q := gocqlx.Query(session.Query(stmt), nil)
stmt, names := qb.Select("gocqlx_test.bench_person").Columns(benchPersonCols...).Limit(100).ToCql()
q := session.Query(stmt, names)
defer q.Release()
b.ResetTimer()
@@ -211,13 +210,13 @@ func loadFixtures() []*benchPerson {
return v
}
func initTable(b *testing.B, session *gocql.Session, people []*benchPerson) {
if err := ExecStmt(session, benchPersonSchema); err != nil {
func initTable(b *testing.B, session gocqlx.Session, people []*benchPerson) {
if err := session.ExecStmt(benchPersonSchema); err != nil {
b.Fatal(err)
}
stmt, names := qb.Insert("gocqlx_test.bench_person").Columns(benchPersonCols...).ToCql()
q := gocqlx.Query(session.Query(stmt), names)
q := session.Query(stmt, names)
for _, p := range people {
if err := q.BindStruct(p).Exec(); err != nil {

View File

@@ -1,9 +1,26 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
package gocqlx_test
import (
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx"
"github.com/scylladb/gocqlx/qb"
)
func ExampleSession() {
cluster := gocql.NewCluster("host")
session, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
// handle error
}
builder := qb.Select("foo")
session.Query(builder.ToCql())
}
func ExampleUDT() {
// Just add gocqlx.UDT to a type, no need to implement marshalling functions
type FullName struct {

View File

@@ -27,7 +27,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
PRIMARY KEY(first_name, last_name)
)`
if err := ExecStmt(session, personSchema); err != nil {
if err := session.ExecStmt(personSchema); err != nil {
t.Fatal("create table:", err)
}
@@ -49,7 +49,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
// Insert, bind data from struct.
{
stmt, names := qb.Insert("gocqlx_test.person").Columns("first_name", "last_name", "email").ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(stmt, names).BindStruct(p)
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
@@ -63,9 +63,9 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
TTL(86400 * time.Second).
Timestamp(time.Now()).
ToCql()
q := session.Query(stmt, names).BindStruct(p)
err := gocqlx.Query(session.Query(stmt), names).BindStruct(p).ExecRelease()
if err != nil {
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
}
}
@@ -78,7 +78,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
Set("email").
Where(qb.Eq("first_name"), qb.Eq("last_name")).
ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(stmt, names).BindStruct(p)
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
@@ -91,7 +91,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
AddNamed("email", "new_email").
Where(qb.Eq("first_name"), qb.Eq("last_name")).
ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindStructMap(p, qb.M{
q := session.Query(stmt, names).BindStructMap(p, qb.M{
"new_email": []string{"patricia2.citzen@gocqlx_test.com", "patricia3.citzen@gocqlx_test.com"},
})
@@ -124,7 +124,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
[]string{"ian.citzen@gocqlx_test.com"},
},
}
q := gocqlx.Query(session.Query(stmt), names).BindStruct(&batch)
q := session.Query(stmt, names).BindStruct(&batch)
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
@@ -136,7 +136,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
var p Person
stmt, names := qb.Select("gocqlx_test.person").Where(qb.Eq("first_name")).ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{
q := session.Query(stmt, names).BindMap(qb.M{
"first_name": "Patricia",
})
@@ -153,7 +153,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
var people []Person
stmt, names := qb.Select("gocqlx_test.person").Where(qb.In("first_name")).ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{
q := session.Query(stmt, names).BindMap(qb.M{
"first_name": []string{"Patricia", "Igy", "Ian"},
})
@@ -178,7 +178,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
Where(qb.Token("first_name").Gt()).
Limit(10).
ToCql()
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(stmt, names).BindStruct(p)
var people []Person
if err := q.SelectRelease(&people); err != nil {
@@ -202,7 +202,8 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
"Citizen",
[]string{"jane.citzen@gocqlx_test.com"},
}
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(stmt, names).BindStruct(p)
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx"
)
var (
@@ -28,7 +29,7 @@ var (
var initOnce sync.Once
// CreateSession creates a new gocql session from flags.
func CreateSession(tb testing.TB) *gocql.Session {
func CreateSession(tb testing.TB) gocqlx.Session {
cluster := createCluster()
return createSessionFromCluster(cluster, tb)
}
@@ -60,7 +61,7 @@ func createCluster() *gocql.ClusterConfig {
return cluster
}
func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocql.Session {
func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) gocqlx.Session {
// Drop and re-create the keyspace once. Different tests should use their own
// individual tables, but can assume that the table does not exist before.
initOnce.Do(func() {
@@ -68,11 +69,10 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq
})
cluster.Keyspace = "gocqlx_test"
session, err := cluster.CreateSession()
session, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
tb.Fatal("CreateSession:", err)
}
return session
}
@@ -80,31 +80,23 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string
c := *cluster
c.Keyspace = "system"
c.Timeout = 30 * time.Second
session, err := c.CreateSession()
session, err := gocqlx.WrapSession(c.CreateSession())
if err != nil {
tb.Fatal(err)
}
defer session.Close()
err = ExecStmt(session, `DROP KEYSPACE IF EXISTS `+keyspace)
err = session.ExecStmt(`DROP KEYSPACE IF EXISTS ` + keyspace)
if err != nil {
tb.Fatalf("unable to drop keyspace: %v", err)
}
err = ExecStmt(session, fmt.Sprintf(`CREATE KEYSPACE %s
err = session.ExecStmt(fmt.Sprintf(`CREATE KEYSPACE %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF))
if err != nil {
tb.Fatalf("unable to create keyspace: %v", err)
}
}
// ExecStmt executes a statement string.
func ExecStmt(s *gocql.Session, stmt string) error {
q := s.Query(stmt).RetryPolicy(nil)
defer q.Release()
return q.Exec()
}

View File

@@ -31,15 +31,6 @@ type Iterx struct {
values []interface{}
}
// Iter creates a new Iterx from gocql.Query using a default mapper.
func Iter(q *gocql.Query) *Iterx {
return &Iterx{
Iter: q.Iter(),
Mapper: DefaultMapper,
unsafe: DefaultUnsafe,
}
}
// Unsafe forces the iterator to ignore missing fields. By default when scanning
// a struct if result row has a column that cannot be mapped to any destination
// field an error is reported. With unsafe such columns are ignored.

View File

@@ -49,11 +49,11 @@ func TestStruct(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TYPE gocqlx_test.FullName (first_Name text, last_name text)`); err != nil {
if err := session.ExecStmt(`CREATE TYPE gocqlx_test.FullName (first_Name text, last_name text)`); err != nil {
t.Fatal("create type:", err)
}
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_table (
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_table (
testuuid timeuuid PRIMARY KEY,
testtimestamp timestamp,
testvarchar varchar,
@@ -123,7 +123,9 @@ func TestStruct(t *testing.T) {
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
}
if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
const stmt = `INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
if err := session.Query(stmt, nil).Bind(
m.Testuuid,
m.Testtimestamp,
m.Testvarchar,
@@ -141,13 +143,13 @@ func TestStruct(t *testing.T) {
m.Testinet,
m.Testcustom,
m.Testudt,
m.Testptrudt).Exec(); err != nil {
m.Testptrudt).ExecRelease(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v StructTable
if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Get(&v); err != nil {
if err := session.Query(`SELECT * FROM struct_table`, nil).Get(&v); err != nil {
t.Fatal("get failed", err)
}
@@ -158,7 +160,7 @@ func TestStruct(t *testing.T) {
t.Run("select", func(t *testing.T) {
var v []StructTable
if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Select(&v); err != nil {
if err := session.Query(`SELECT * FROM struct_table`, nil).Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -173,7 +175,7 @@ func TestStruct(t *testing.T) {
t.Run("select ptr", func(t *testing.T) {
var v []*StructTable
if err := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Select(&v); err != nil {
if err := session.Query(`SELECT * FROM struct_table`, nil).Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -192,7 +194,7 @@ func TestStruct(t *testing.T) {
n int
)
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
i := session.Query(`SELECT * FROM struct_table`, nil).Iter()
for i.StructScan(&v) {
n++
}
@@ -212,18 +214,18 @@ func TestStruct(t *testing.T) {
func TestScannable(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
m := FullName{"John", "Doe"}
if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil {
if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, nil).Bind(m).Exec(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v FullName
if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Get(&v); err != nil {
if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Get(&v); err != nil {
t.Fatal("get failed", err)
}
@@ -234,7 +236,7 @@ func TestScannable(t *testing.T) {
t.Run("select", func(t *testing.T) {
var v []FullName
if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Select(&v); err != nil {
if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -249,7 +251,7 @@ func TestScannable(t *testing.T) {
t.Run("select ptr", func(t *testing.T) {
var v []*FullName
if err := gocqlx.Query(session.Query(`SELECT testfullname FROM scannable_table`), nil).Select(&v); err != nil {
if err := session.Query(`SELECT testfullname FROM scannable_table`, nil).Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -266,19 +268,19 @@ func TestScannable(t *testing.T) {
func TestStructOnly(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_only_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_only_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil {
t.Fatal("create table:", err)
}
m := FullName{"John", "Doe"}
if err := session.Query(`INSERT INTO struct_only_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
if err := session.Query(`INSERT INTO struct_only_table (first_name, last_name) values (?, ?)`, nil).Bind(m.FirstName, m.LastName).Exec(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v FullName
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Get(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Get(&v); err != nil {
t.Fatal("get failed", err)
}
@@ -289,7 +291,7 @@ func TestStructOnly(t *testing.T) {
t.Run("select", func(t *testing.T) {
var v []FullName
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Select(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -304,7 +306,7 @@ func TestStructOnly(t *testing.T) {
t.Run("select ptr", func(t *testing.T) {
var v []*FullName
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).StructOnly().Select(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Iter().StructOnly().Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -319,7 +321,7 @@ func TestStructOnly(t *testing.T) {
t.Run("get error", func(t *testing.T) {
var v FullName
err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).Get(&v)
err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Get(&v)
if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") {
t.Fatal("get expected validation error got", err)
}
@@ -327,7 +329,7 @@ func TestStructOnly(t *testing.T) {
t.Run("select error", func(t *testing.T) {
var v []FullName
err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_table`)).Select(&v)
err := session.Query(`SELECT first_name, last_name FROM struct_only_table`, nil).Select(&v)
if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") {
t.Fatal("select expected validation error got", err)
}
@@ -337,7 +339,7 @@ func TestStructOnly(t *testing.T) {
func TestStructOnlyUDT(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.struct_only_udt_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.struct_only_udt_table (first_name text, last_name text, PRIMARY KEY (first_name, last_name))`); err != nil {
t.Fatal("create table:", err)
}
@@ -348,13 +350,13 @@ func TestStructOnlyUDT(t *testing.T) {
},
}
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, nil).Bind(m.FirstName, m.LastName).Exec(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v FullNameUDT
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Get(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Get(&v); err != nil {
t.Fatal("get failed", err)
}
@@ -365,7 +367,7 @@ func TestStructOnlyUDT(t *testing.T) {
t.Run("select", func(t *testing.T) {
var v []FullNameUDT
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Select(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -380,7 +382,7 @@ func TestStructOnlyUDT(t *testing.T) {
t.Run("select ptr", func(t *testing.T) {
var v []*FullNameUDT
if err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).StructOnly().Select(&v); err != nil {
if err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Iter().StructOnly().Select(&v); err != nil {
t.Fatal("select failed", err)
}
@@ -395,7 +397,7 @@ func TestStructOnlyUDT(t *testing.T) {
t.Run("get error", func(t *testing.T) {
var v FullNameUDT
err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).Get(&v)
err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Get(&v)
if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") {
t.Fatal("get expected validation error got", err)
}
@@ -403,7 +405,7 @@ func TestStructOnlyUDT(t *testing.T) {
t.Run("select error", func(t *testing.T) {
var v []FullNameUDT
err := gocqlx.Iter(session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`)).Select(&v)
err := session.Query(`SELECT first_name, last_name FROM struct_only_udt_table`, nil).Select(&v)
if err == nil || !strings.HasPrefix(err.Error(), "expected 1 column in result") {
t.Fatal("select expected validation error got", err)
}
@@ -413,10 +415,10 @@ func TestStructOnlyUDT(t *testing.T) {
func TestUnsafe(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
t.Fatal("create table:", err)
}
if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, "test", "test").Exec(); err != nil {
if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil {
t.Fatal("insert:", err)
}
@@ -426,16 +428,16 @@ func TestUnsafe(t *testing.T) {
t.Run("safe get", func(t *testing.T) {
var v UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
err := session.Query(`SELECT * FROM unsafe_table`, nil).Get(&v)
if err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
})
t.Run("safe select", func(t *testing.T) {
var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
err := session.Query(`SELECT * FROM unsafe_table`, nil).Select(&v)
if err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
if cap(v) > 0 {
@@ -445,8 +447,8 @@ func TestUnsafe(t *testing.T) {
t.Run("unsafe get", func(t *testing.T) {
var v UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Unsafe().Get(&v); err != nil {
err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Unsafe().Get(&v)
if err != nil {
t.Fatal(err)
}
if v.Testtext != "test" {
@@ -456,8 +458,8 @@ func TestUnsafe(t *testing.T) {
t.Run("unsafe select", func(t *testing.T) {
var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Unsafe().Select(&v); err != nil {
err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Unsafe().Select(&v)
if err != nil {
t.Fatal(err)
}
if len(v) != 1 {
@@ -470,10 +472,12 @@ func TestUnsafe(t *testing.T) {
t.Run("DefaultUnsafe select", func(t *testing.T) {
gocqlx.DefaultUnsafe = true
defer func() { gocqlx.DefaultUnsafe = false }()
defer func() {
gocqlx.DefaultUnsafe = false
}()
var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Select(&v); err != nil {
err := session.Query(`SELECT * FROM unsafe_table`, nil).Iter().Select(&v)
if err != nil {
t.Fatal(err)
}
if len(v) != 1 {
@@ -488,7 +492,7 @@ func TestUnsafe(t *testing.T) {
func TestNotFound(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.not_found_table (testtext text PRIMARY KEY)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.not_found_table (testtext text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
@@ -498,9 +502,7 @@ func TestNotFound(t *testing.T) {
t.Run("get cql error", func(t *testing.T) {
var v NotFoundTable
i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table WRONG`).RetryPolicy(nil))
err := i.Get(&v)
err := session.Query(`SELECT * FROM not_found_table WRONG`, nil).RetryPolicy(nil).Get(&v)
if err == nil || !strings.Contains(err.Error(), "WRONG") {
t.Fatal(err)
}
@@ -508,17 +510,15 @@ func TestNotFound(t *testing.T) {
t.Run("get", func(t *testing.T) {
var v NotFoundTable
i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`))
if err := i.Get(&v); err != gocql.ErrNotFound {
err := session.Query(`SELECT * FROM not_found_table`, nil).Get(&v)
if err != gocql.ErrNotFound {
t.Fatal("expected ErrNotFound", "got", err)
}
})
t.Run("select cql error", func(t *testing.T) {
var v []NotFoundTable
i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table WRONG`).RetryPolicy(nil))
err := i.Select(&v)
err := session.Query(`SELECT * FROM not_found_table WRONG`, nil).RetryPolicy(nil).Select(&v)
if err == nil || !strings.Contains(err.Error(), "WRONG") {
t.Fatal(err)
}
@@ -526,8 +526,8 @@ func TestNotFound(t *testing.T) {
t.Run("select", func(t *testing.T) {
var v []NotFoundTable
i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`))
if err := i.Select(&v); err != nil {
err := session.Query(`SELECT * FROM not_found_table`, nil).Select(&v)
if err != nil {
t.Fatal(err)
}
if cap(v) > 0 {
@@ -539,7 +539,7 @@ func TestNotFound(t *testing.T) {
func TestErrorOnNil(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
@@ -549,19 +549,19 @@ func TestErrorOnNil(t *testing.T) {
)
t.Run("get", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
err := session.Query(stmt, nil).Get(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Get()=%q expected %q error", err, golden)
}
})
t.Run("select", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
err := session.Query(stmt, nil).Select(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Select()=%q expected %q error", err, golden)
}
})
t.Run("struct scan", func(t *testing.T) {
i := gocqlx.Iter(session.Query(stmt))
i := session.Query(stmt, nil).Iter()
i.StructScan(nil)
err := i.Close()
if err == nil || err.Error() != golden {
@@ -573,14 +573,14 @@ func TestErrorOnNil(t *testing.T) {
func TestPaging(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.paging_table (id int PRIMARY KEY, val int)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.paging_table (id int PRIMARY KEY, val int)`); err != nil {
t.Fatal("create table:", err)
}
if err := ExecStmt(session, `CREATE INDEX id_val_index ON gocqlx_test.paging_table (val)`); err != nil {
if err := session.ExecStmt(`CREATE INDEX id_val_index ON gocqlx_test.paging_table (val)`); err != nil {
t.Fatal("create index:", err)
}
stmt, names := qb.Insert("gocqlx_test.paging_table").Columns("id", "val").ToCql()
q := gocqlx.Query(session.Query(stmt), names)
q := session.Query(qb.Insert("gocqlx_test.paging_table").Columns("id", "val").ToCql())
for i := 0; i < 5000; i++ {
if err := q.Bind(i, i).Exec(); err != nil {
t.Fatal(err)
@@ -597,12 +597,13 @@ func TestPaging(t *testing.T) {
Where(qb.Lt("val")).
AllowFiltering().
Columns("id", "val").ToCql()
it := gocqlx.Query(session.Query(stmt, 100).PageSize(10), names).Iter()
defer it.Close()
iter := session.Query(stmt, names).Bind(100).PageSize(10).Iter()
defer iter.Close()
var cnt int
for {
p := &Paging{}
if !it.StructScan(p) {
if !iter.StructScan(p) {
break
}
cnt++

View File

@@ -12,4 +12,6 @@ import (
// snake case. It can be set to whatever you want, but it is encouraged to be
// set before gocqlx is used as name-to-field mappings are cached after first
// use on a type.
//
// A custom mapper can always be set per Sessionm, Query and Iter.
var DefaultMapper = reflectx.NewMapperFunc("db", reflectx.CamelToSnakeASCII)

View File

@@ -7,7 +7,7 @@ package migrate
import (
"context"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx"
)
// CallbackEvent specifies type of the event when calling CallbackFunc.
@@ -21,7 +21,7 @@ const (
// CallbackFunc enables interrupting the migration process and executing code
// while migrating. If error is returned the migration is aborted.
type CallbackFunc func(ctx context.Context, session *gocql.Session, ev CallbackEvent, name string) error
type CallbackFunc func(ctx context.Context, session gocqlx.Session, ev CallbackEvent, name string) error
// Callback is called before and after each migration.
// See CallbackFunc for details.

View File

@@ -64,12 +64,12 @@ type Info struct {
}
// List provides a listing of applied migrations.
func List(ctx context.Context, session *gocql.Session) ([]*Info, error) {
func List(ctx context.Context, session gocqlx.Session) ([]*Info, error) {
if err := ensureInfoTable(ctx, session); err != nil {
return nil, err
}
q := gocqlx.Query(session.Query(selectInfo).WithContext(ctx), nil)
q := session.ContextQuery(ctx, selectInfo, nil)
var v []*Info
if err := q.SelectRelease(&v); err == gocql.ErrNotFound {
@@ -85,12 +85,12 @@ func List(ctx context.Context, session *gocql.Session) ([]*Info, error) {
return v, nil
}
func ensureInfoTable(ctx context.Context, session *gocql.Session) error {
return gocqlx.Query(session.Query(infoSchema).WithContext(ctx), nil).ExecRelease()
func ensureInfoTable(ctx context.Context, session gocqlx.Session) error {
return session.ContextQuery(ctx, infoSchema, nil).ExecRelease()
}
// Migrate reads the cql files from a directory and applies required migrations.
func Migrate(ctx context.Context, session *gocql.Session, dir string) error {
func Migrate(ctx context.Context, session gocqlx.Session, dir string) error {
// get database migrations
dbm, err := List(ctx, session)
if err != nil {
@@ -147,7 +147,7 @@ func Migrate(ctx context.Context, session *gocql.Session, dir string) error {
return nil
}
func applyMigration(ctx context.Context, session *gocql.Session, path string, done int) error {
func applyMigration(ctx context.Context, session gocqlx.Session, path string, done int) error {
f, err := os.Open(path)
if err != nil {
return err
@@ -173,8 +173,8 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
"end_time",
).ToCql()
iq := gocqlx.Query(session.Query(stmt).WithContext(ctx), names)
defer iq.Release()
update := session.ContextQuery(ctx, stmt, names)
defer update.Release()
if DefaultAwaitSchemaAgreement.ShouldAwait(AwaitSchemaAgreementBeforeEachFile) {
if err = session.AwaitSchemaAgreement(ctx); err != nil {
@@ -216,7 +216,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
}
// execute
q := gocqlx.Query(session.Query(stmt).RetryPolicy(nil).WithContext(ctx), nil)
q := session.ContextQuery(ctx, stmt, nil).RetryPolicy(nil)
if err := q.ExecRelease(); err != nil {
return fmt.Errorf("statement %d failed: %s", i, err)
}
@@ -224,7 +224,7 @@ func applyMigration(ctx context.Context, session *gocql.Session, path string, do
// update info
info.Done = i
info.EndTime = time.Now()
if err := iq.BindStruct(info).Exec(); err != nil {
if err := update.BindStruct(info).Exec(); err != nil {
return fmt.Errorf("migration statement %d failed: %s", i, err)
}
}

View File

@@ -15,7 +15,7 @@ import (
"strings"
"testing"
"github.com/gocql/gocql"
"github.com/scylladb/gocqlx"
. "github.com/scylladb/gocqlx/gocqlxtest"
"github.com/scylladb/gocqlx/migrate"
)
@@ -30,16 +30,16 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.migrate_table (
var insertMigrate = `INSERT INTO gocqlx_test.migrate_table (testint, testuuid) VALUES (%d, now())`
func recreateTables(tb testing.TB, session *gocql.Session) {
func recreateTables(tb testing.TB, session gocqlx.Session) {
tb.Helper()
if err := ExecStmt(session, "DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil {
if err := session.ExecStmt("DROP TABLE IF EXISTS gocqlx_test.gocqlx_migrate"); err != nil {
tb.Fatal(err)
}
if err := ExecStmt(session, migrateSchema); err != nil {
if err := session.ExecStmt(migrateSchema); err != nil {
tb.Fatal(err)
}
if err := ExecStmt(session, "TRUNCATE gocqlx_test.migrate_table"); err != nil {
if err := session.ExecStmt("TRUNCATE gocqlx_test.migrate_table"); err != nil {
tb.Fatal(err)
}
}
@@ -105,7 +105,7 @@ func TestMigrationNoSemicolon(t *testing.T) {
defer session.Close()
recreateTables(t, session)
if err := ExecStmt(session, migrateSchema); err != nil {
if err := session.ExecStmt(migrateSchema); err != nil {
t.Fatal(err)
}
@@ -134,7 +134,7 @@ func TestMigrationCallback(t *testing.T) {
beforeCalled int
afterCalled int
)
migrate.Callback = func(ctx context.Context, session *gocql.Session, ev migrate.CallbackEvent, name string) error {
migrate.Callback = func(ctx context.Context, session gocqlx.Session, ev migrate.CallbackEvent, name string) error {
switch ev {
case migrate.BeforeMigration:
beforeCalled += 1
@@ -166,7 +166,7 @@ func TestMigrationCallback(t *testing.T) {
defer session.Close()
recreateTables(t, session)
if err := ExecStmt(session, migrateSchema); err != nil {
if err := session.ExecStmt(migrateSchema); err != nil {
t.Fatal(err)
}
@@ -215,14 +215,11 @@ func makeMigrationDir(tb testing.TB, n int) (dir string) {
return dir
}
func countMigrations(tb testing.TB, session *gocql.Session) int {
func countMigrations(tb testing.TB, session gocqlx.Session) int {
tb.Helper()
q := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table")
defer q.Release()
var v int
if err := q.Scan(&v); err != nil {
if err := session.Query("SELECT COUNT(*) FROM gocqlx_test.migrate_table", nil).Get(&v); err != nil {
tb.Fatal(err)
}
return v

View File

@@ -90,6 +90,8 @@ type Queryx struct {
}
// Query creates a new Queryx from gocql.Query using a default mapper.
//
// Deprecated: Use Session API instead.
func Query(q *gocql.Query, names []string) *Queryx {
return &Queryx{
Query: q,
@@ -266,7 +268,9 @@ func (q *Queryx) SelectRelease(dest interface{}) error {
// big to be loaded with Select in order to do row by row iteration.
// See Iterx StructScan function.
func (q *Queryx) Iter() *Iterx {
i := Iter(q.Query)
i.Mapper = q.Mapper
return i
return &Iterx{
Iter: q.Query.Iter(),
Mapper: q.Mapper,
unsafe: DefaultUnsafe,
}
}

60
session.go Normal file
View File

@@ -0,0 +1,60 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
package gocqlx
import (
"context"
"github.com/gocql/gocql"
"github.com/scylladb/go-reflectx"
)
// Session wraps gocql.Session and provides a modified Query function that
// returns Queryx instance.
// The original Session instance can be accessed as Session.
// The default mapper uses `db` tag and automatically converts struct field
// names to snake case. If needed package reflectx provides constructors
// for other types of mappers.
type Session struct {
*gocql.Session
Mapper *reflectx.Mapper
}
// WrapSession should be called on CreateSession() gocql function to convert
// the created session to gocqlx.Session.
func WrapSession(session *gocql.Session, err error) (Session, error) {
return Session{
Session: session,
Mapper: DefaultMapper,
}, err
}
// ContextQuery is a helper function that allows to pass context when creating
// a query, see the "Query" function .
func (s Session) ContextQuery(ctx context.Context, stmt string, names []string) *Queryx {
return &Queryx{
Query: s.Session.Query(stmt).WithContext(ctx),
Names: names,
Mapper: s.Mapper,
}
}
// Query creates a new Queryx using the session mapper.
// The stmt and names parameters are typically result of a query builder
// (package qb) ToCql() function or come from table model (package table).
// The names parameter is a list of query parameters' names and it's used for
// binding.
func (s Session) Query(stmt string, names []string) *Queryx {
return &Queryx{
Query: s.Session.Query(stmt),
Names: names,
Mapper: s.Mapper,
}
}
// ExecStmt creates query and executes the given statement.
func (s Session) ExecStmt(stmt string) error {
return s.Query(stmt, nil).ExecRelease()
}

View File

@@ -9,7 +9,6 @@ package table_test
import (
"testing"
"github.com/scylladb/gocqlx"
. "github.com/scylladb/gocqlx/gocqlxtest"
"github.com/scylladb/gocqlx/qb"
"github.com/scylladb/gocqlx/table"
@@ -26,7 +25,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
email list<text>,
PRIMARY KEY(first_name, last_name)
)`
if err := ExecStmt(session, personSchema); err != nil {
if err := session.ExecStmt(personSchema); err != nil {
t.Fatal("create table:", err)
}
@@ -58,8 +57,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
[]string{"patricia.citzen@gocqlx_test.com"},
}
stmt, names := personTable.Insert()
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(personTable.Insert()).BindStruct(p)
if err := q.ExecRelease(); err != nil {
t.Fatal(err)
}
@@ -73,8 +71,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
nil, // no email
}
stmt, names := personTable.Get() // you can filter columns too
q := gocqlx.Query(session.Query(stmt), names).BindStruct(p)
q := session.Query(personTable.Get()).BindStruct(p)
if err := q.GetRelease(&p); err != nil {
t.Fatal(err)
}
@@ -87,9 +84,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
{
var people []Person
stmt, names := personTable.Select() // you can filter columns too
q := gocqlx.Query(session.Query(stmt), names).BindMap(qb.M{"first_name": "Patricia"})
q := session.Query(personTable.Select()).BindMap(qb.M{"first_name": "Patricia"})
if err := q.SelectRelease(&people); err != nil {
t.Fatal(err)
}