queryx: Added CAS functions (#98)

Added ExecCAS, ExecCASRelease, GetCAS, GetCASRelease functions
suitable for INSERT ... IF NOT EXISTS and UPDATE's containing IF statement.
Functions returns information wheter query was applied or not, together
with pre-image.

Fixes #98
This commit is contained in:
Maciej Zimnoch
2020-04-20 19:43:32 +02:00
committed by Michal Jan Matczuk
parent 9655ae5b49
commit 52c5f6873a
4 changed files with 164 additions and 3 deletions

View File

@@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
first_name text, first_name text,
last_name text, last_name text,
email list<text>, email list<text>,
salary int,
PRIMARY KEY(first_name, last_name) PRIMARY KEY(first_name, last_name)
)` )`
@@ -38,12 +39,14 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
FirstName string FirstName string
LastName string LastName string
Email []string Email []string
Salary int
} }
p := Person{ p := Person{
"Patricia", "Patricia",
"Citizen", "Citizen",
[]string{"patricia.citzen@gocqlx_test.com"}, []string{"patricia.citzen@gocqlx_test.com"},
500,
} }
// Insert, bind data from struct. // Insert, bind data from struct.
@@ -117,11 +120,13 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
"Igy", "Igy",
"Citizen", "Citizen",
[]string{"igy.citzen@gocqlx_test.com"}, []string{"igy.citzen@gocqlx_test.com"},
500,
}, },
B: Person{ B: Person{
"Ian", "Ian",
"Citizen", "Citizen",
[]string{"ian.citzen@gocqlx_test.com"}, []string{"ian.citzen@gocqlx_test.com"},
500,
}, },
} }
q := session.Query(stmt, names).BindStruct(&batch) q := session.Query(stmt, names).BindStruct(&batch)
@@ -171,6 +176,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
"Ian", "Ian",
"Citizen", "Citizen",
[]string{"ian.citzen@gocqlx_test.com"}, []string{"ian.citzen@gocqlx_test.com"},
500,
} }
stmt, names := qb.Select("gocqlx_test.person"). stmt, names := qb.Select("gocqlx_test.person").
@@ -201,6 +207,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
"Jane", "Jane",
"Citizen", "Citizen",
[]string{"jane.citzen@gocqlx_test.com"}, []string{"jane.citzen@gocqlx_test.com"},
500,
} }
q := session.Query(stmt, names).BindStruct(p) q := session.Query(stmt, names).BindStruct(p)
@@ -208,4 +215,44 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person (
t.Fatal(err) t.Fatal(err)
} }
} }
// Support for Lightweight Transactions
{
p := Person{
"Stephen",
"Johns",
[]string{"stephen.johns@gocqlx_test.com"},
500,
}
stmt, names := qb.Insert("gocqlx_test.person").
Columns("first_name", "last_name", "email", "salary").
Unique().
ToCql()
applied, err := session.Query(stmt, names).BindStruct(p).ExecCASRelease()
if err != nil {
t.Fatal(err)
}
t.Log(applied)
stmt, names = qb.Update("gocqlx_test.person").
SetNamed("salary", "new_salary").
Where(qb.Eq("first_name"), qb.Eq("last_name")).
If(qb.LtNamed("salary", "old_salary")).
ToCql()
q := session.Query(stmt, names).BindStructMap(&p, qb.M{
"old_salary": 1000,
"new_salary": 1500,
})
applied, err = q.GetCAS(&p)
if err != nil {
t.Fatal(err)
}
t.Log(applied, p)
}
} }

View File

@@ -24,6 +24,7 @@ type Iterx struct {
unsafe bool unsafe bool
structOnly bool structOnly bool
applied bool
err error err error
// Cache memory for a rows during iteration in structScan. // Cache memory for a rows during iteration in structScan.
@@ -252,6 +253,8 @@ func (iter *Iterx) StructScan(dest interface{}) bool {
return iter.structScan(value) return iter.structScan(value)
} }
const appliedColumn = "[applied]"
func (iter *Iterx) structScan(value reflect.Value) bool { func (iter *Iterx) structScan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr { if value.Kind() != reflect.Ptr {
panic("value must be a pointer") panic("value must be a pointer")
@@ -259,16 +262,20 @@ func (iter *Iterx) structScan(value reflect.Value) bool {
if iter.fields == nil { if iter.fields == nil {
columns := columnNames(iter.Iter.Columns()) columns := columnNames(iter.Iter.Columns())
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) cas := len(columns) > 0 && columns[0] == appliedColumn
// if we are not unsafe and are missing fields, return an error iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
if !iter.unsafe { // if we are not unsafe and it's not CAS query and are missing fields, return an error
if !iter.unsafe && !cas {
if f, err := missingFields(iter.fields); err != nil { if f, err := missingFields(iter.fields); err != nil {
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type()) iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
return false return false
} }
} }
iter.values = make([]interface{}, len(columns)) iter.values = make([]interface{}, len(columns))
if cas {
iter.values[0] = &iter.applied
}
} }
if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil { if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {

View File

@@ -622,3 +622,73 @@ func TestIterxPaging(t *testing.T) {
t.Fatal("expected 100", "got", cnt) t.Fatal("expected 100", "got", cnt)
} }
} }
func TestIterx_CASInsertAndUpdates(t *testing.T) {
session := CreateSession(t)
defer session.Close()
const (
id = 0
baseSalary = 1000
minSalary = 2000
)
john := struct {
ID int
Salary int
}{ID: id, Salary: baseSalary}
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.salaries (id int PRIMARY KEY, salary int)`); err != nil {
t.Fatal("create table:", err)
}
insertQ := session.Query(qb.Insert("gocqlx_test.salaries").Columns("id", "salary").Unique().ToCql())
applied, err := insertQ.BindStruct(john).ExecCAS()
if err != nil {
t.Fatal(err)
}
if !applied {
t.Error("Expected first insert success")
}
applied, err = insertQ.BindStruct(john).ExecCASRelease()
if err != nil {
t.Fatal(err)
}
if applied {
t.Error("Expected second insert to not be applied")
}
updateQ := session.Query(qb.Update("gocqlx_test.salaries").
SetNamed("salary", "min_salary").
Where(qb.Eq("id")).
If(qb.LtNamed("salary", "min_salary")).
ToCql(),
)
applied, err = updateQ.BindStructMap(john, qb.M{
"min_salary": minSalary,
}).GetCAS(&john)
if err != nil {
t.Fatal(err)
}
if !applied {
t.Error("Expected update to be applied")
}
if john.Salary != baseSalary {
t.Error("Expected to have pre-image in struct after GetCAS")
}
applied, err = updateQ.BindStructMap(john, qb.M{
"min_salary": minSalary * 2,
}).GetCASRelease(&john)
if err != nil {
t.Fatal(err)
}
if !applied {
t.Error("Expected update to be applied")
}
if john.Salary != minSalary {
t.Error("Expected to have pre-image in struct after GetCAS")
}
}

View File

@@ -209,6 +209,23 @@ func (q *Queryx) ExecRelease() error {
return q.Exec() return q.Exec()
} }
// ExecCAS executes the Lightweight Transaction query, returns whether query was applied.
// See: https://docs.scylladb.com/using-scylla/lwt/ for more details.
func (q *Queryx) ExecCAS() (applied bool, err error) {
iter := q.Iter().StructOnly()
if err := iter.Get(&struct{}{}); err != nil {
return false, err
}
return iter.applied, iter.Close()
}
// ExecCASRelease calls ExecCAS and releases the query, a released query cannot be
// reused.
func (q *Queryx) ExecCASRelease() (bool, error) {
defer q.Release()
return q.ExecCAS()
}
// Get scans first row into a destination and closes the iterator. // Get scans first row into a destination and closes the iterator.
// //
// If the destination type is a struct pointer, then Iter.StructScan will be // If the destination type is a struct pointer, then Iter.StructScan will be
@@ -236,6 +253,26 @@ func (q *Queryx) GetRelease(dest interface{}) error {
return q.Get(dest) return q.Get(dest)
} }
// GetCAS executes a lightweight transaction.
// If the transaction fails because the existing values did not match,
// the previous values will be stored in dest object.
// See: https://docs.scylladb.com/using-scylla/lwt/ for more details.
func (q *Queryx) GetCAS(dest interface{}) (applied bool, err error) {
iter := q.Iter()
if err := iter.Get(dest); err != nil {
return false, err
}
return iter.applied, iter.Close()
}
// GetCASRelease calls GetCAS and releases the query, a released query cannot be
// reused.
func (q *Queryx) GetCASRelease(dest interface{}) (bool, error) {
defer q.Release()
return q.GetCAS(dest)
}
// Select scans all rows into a destination, which must be a pointer to slice // Select scans all rows into a destination, which must be a pointer to slice
// of any type, and closes the iterator. // of any type, and closes the iterator.
// //