From 52c5f6873a06b8e00205d89ef271ac7255c74e88 Mon Sep 17 00:00:00 2001 From: Maciej Zimnoch Date: Mon, 20 Apr 2020 19:43:32 +0200 Subject: [PATCH] queryx: Added CAS functions (#98) Added ExecCAS, ExecCASRelease, GetCAS, GetCASRelease functions suitable for INSERT ... IF NOT EXISTS and UPDATE's containing IF statement. Functions returns information wheter query was applied or not, together with pre-image. Fixes #98 --- example_test.go | 47 +++++++++++++++++++++++++++++++++ iterx.go | 13 ++++++--- iterx_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++ queryx.go | 37 ++++++++++++++++++++++++++ 4 files changed, 164 insertions(+), 3 deletions(-) diff --git a/example_test.go b/example_test.go index 224d754..8ffeb3e 100644 --- a/example_test.go +++ b/example_test.go @@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( first_name text, last_name text, email list, + salary int, PRIMARY KEY(first_name, last_name) )` @@ -38,12 +39,14 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( FirstName string LastName string Email []string + Salary int } p := Person{ "Patricia", "Citizen", []string{"patricia.citzen@gocqlx_test.com"}, + 500, } // Insert, bind data from struct. @@ -117,11 +120,13 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( "Igy", "Citizen", []string{"igy.citzen@gocqlx_test.com"}, + 500, }, B: Person{ "Ian", "Citizen", []string{"ian.citzen@gocqlx_test.com"}, + 500, }, } q := session.Query(stmt, names).BindStruct(&batch) @@ -171,6 +176,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( "Ian", "Citizen", []string{"ian.citzen@gocqlx_test.com"}, + 500, } stmt, names := qb.Select("gocqlx_test.person"). @@ -201,6 +207,7 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( "Jane", "Citizen", []string{"jane.citzen@gocqlx_test.com"}, + 500, } q := session.Query(stmt, names).BindStruct(p) @@ -208,4 +215,44 @@ CREATE TABLE IF NOT EXISTS gocqlx_test.person ( t.Fatal(err) } } + + // Support for Lightweight Transactions + { + + p := Person{ + "Stephen", + "Johns", + []string{"stephen.johns@gocqlx_test.com"}, + 500, + } + + stmt, names := qb.Insert("gocqlx_test.person"). + Columns("first_name", "last_name", "email", "salary"). + Unique(). + ToCql() + + applied, err := session.Query(stmt, names).BindStruct(p).ExecCASRelease() + if err != nil { + t.Fatal(err) + } + + t.Log(applied) + + stmt, names = qb.Update("gocqlx_test.person"). + SetNamed("salary", "new_salary"). + Where(qb.Eq("first_name"), qb.Eq("last_name")). + If(qb.LtNamed("salary", "old_salary")). + ToCql() + q := session.Query(stmt, names).BindStructMap(&p, qb.M{ + "old_salary": 1000, + "new_salary": 1500, + }) + + applied, err = q.GetCAS(&p) + if err != nil { + t.Fatal(err) + } + + t.Log(applied, p) + } } diff --git a/iterx.go b/iterx.go index 52be4a1..54f37dd 100644 --- a/iterx.go +++ b/iterx.go @@ -24,6 +24,7 @@ type Iterx struct { unsafe bool structOnly bool + applied bool err error // Cache memory for a rows during iteration in structScan. @@ -252,6 +253,8 @@ func (iter *Iterx) StructScan(dest interface{}) bool { return iter.structScan(value) } +const appliedColumn = "[applied]" + func (iter *Iterx) structScan(value reflect.Value) bool { if value.Kind() != reflect.Ptr { panic("value must be a pointer") @@ -259,16 +262,20 @@ func (iter *Iterx) structScan(value reflect.Value) bool { if iter.fields == nil { columns := columnNames(iter.Iter.Columns()) - iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) + cas := len(columns) > 0 && columns[0] == appliedColumn - // if we are not unsafe and are missing fields, return an error - if !iter.unsafe { + iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) + // if we are not unsafe and it's not CAS query and are missing fields, return an error + if !iter.unsafe && !cas { if f, err := missingFields(iter.fields); err != nil { iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type()) return false } } iter.values = make([]interface{}, len(columns)) + if cas { + iter.values[0] = &iter.applied + } } if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil { diff --git a/iterx_test.go b/iterx_test.go index 546991d..8705dcf 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -622,3 +622,73 @@ func TestIterxPaging(t *testing.T) { t.Fatal("expected 100", "got", cnt) } } + +func TestIterx_CASInsertAndUpdates(t *testing.T) { + session := CreateSession(t) + defer session.Close() + + const ( + id = 0 + baseSalary = 1000 + minSalary = 2000 + ) + + john := struct { + ID int + Salary int + }{ID: id, Salary: baseSalary} + + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.salaries (id int PRIMARY KEY, salary int)`); err != nil { + t.Fatal("create table:", err) + } + + insertQ := session.Query(qb.Insert("gocqlx_test.salaries").Columns("id", "salary").Unique().ToCql()) + applied, err := insertQ.BindStruct(john).ExecCAS() + if err != nil { + t.Fatal(err) + } + if !applied { + t.Error("Expected first insert success") + } + + applied, err = insertQ.BindStruct(john).ExecCASRelease() + if err != nil { + t.Fatal(err) + } + if applied { + t.Error("Expected second insert to not be applied") + } + + updateQ := session.Query(qb.Update("gocqlx_test.salaries"). + SetNamed("salary", "min_salary"). + Where(qb.Eq("id")). + If(qb.LtNamed("salary", "min_salary")). + ToCql(), + ) + + applied, err = updateQ.BindStructMap(john, qb.M{ + "min_salary": minSalary, + }).GetCAS(&john) + if err != nil { + t.Fatal(err) + } + if !applied { + t.Error("Expected update to be applied") + } + if john.Salary != baseSalary { + t.Error("Expected to have pre-image in struct after GetCAS") + } + + applied, err = updateQ.BindStructMap(john, qb.M{ + "min_salary": minSalary * 2, + }).GetCASRelease(&john) + if err != nil { + t.Fatal(err) + } + if !applied { + t.Error("Expected update to be applied") + } + if john.Salary != minSalary { + t.Error("Expected to have pre-image in struct after GetCAS") + } +} diff --git a/queryx.go b/queryx.go index cc5d85a..0a1f715 100644 --- a/queryx.go +++ b/queryx.go @@ -209,6 +209,23 @@ func (q *Queryx) ExecRelease() error { return q.Exec() } +// ExecCAS executes the Lightweight Transaction query, returns whether query was applied. +// See: https://docs.scylladb.com/using-scylla/lwt/ for more details. +func (q *Queryx) ExecCAS() (applied bool, err error) { + iter := q.Iter().StructOnly() + if err := iter.Get(&struct{}{}); err != nil { + return false, err + } + return iter.applied, iter.Close() +} + +// ExecCASRelease calls ExecCAS and releases the query, a released query cannot be +// reused. +func (q *Queryx) ExecCASRelease() (bool, error) { + defer q.Release() + return q.ExecCAS() +} + // Get scans first row into a destination and closes the iterator. // // If the destination type is a struct pointer, then Iter.StructScan will be @@ -236,6 +253,26 @@ func (q *Queryx) GetRelease(dest interface{}) error { return q.Get(dest) } +// GetCAS executes a lightweight transaction. +// If the transaction fails because the existing values did not match, +// the previous values will be stored in dest object. +// See: https://docs.scylladb.com/using-scylla/lwt/ for more details. +func (q *Queryx) GetCAS(dest interface{}) (applied bool, err error) { + iter := q.Iter() + if err := iter.Get(dest); err != nil { + return false, err + } + + return iter.applied, iter.Close() +} + +// GetCASRelease calls GetCAS and releases the query, a released query cannot be +// reused. +func (q *Queryx) GetCASRelease(dest interface{}) (bool, error) { + defer q.Release() + return q.GetCAS(dest) +} + // Select scans all rows into a destination, which must be a pointer to slice // of any type, and closes the iterator. //