add Batch wrapper with BindStruct method

This commit is contained in:
Gabriel Nelle
2023-03-03 16:14:23 +01:00
committed by Maciej Zimnoch
parent fc92258512
commit dec046bd85
3 changed files with 181 additions and 27 deletions

49
batchx.go Normal file
View File

@@ -0,0 +1,49 @@
package gocqlx
import (
"github.com/gocql/gocql"
)
type Batch struct {
*gocql.Batch
}
// NewBatch creates a new batch operation using defaults defined in the cluster.
func (s *Session) NewBatch(bt gocql.BatchType) *Batch {
return &Batch{
Batch: s.Session.NewBatch(bt),
}
}
// BindStruct binds query named parameters to values from arg using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) BindStruct(qry *Queryx, arg interface{}) error {
args, err := qry.bindStructArgs(arg, nil)
if err != nil {
return err
}
b.Query(qry.Statement(), args...)
return nil
}
// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {
return s.Session.ExecuteBatch(batch.Batch)
}
// ExecuteBatchCAS executes a batch operation and returns true if successful and
// an iterator (to scan additional rows if more than one conditional statement)
// was sent.
// Further scans on the interator must also remember to include
// the applied boolean as the first argument to *Iter.Scan
func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *gocql.Iter, err error) {
return s.Session.ExecuteBatchCAS(batch.Batch, dest...)
}
// MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
// however it accepts a map rather than a list of arguments for the initial
// scan.
func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *gocql.Iter, err error) {
return s.Session.MapExecuteBatchCAS(batch.Batch, dest)
}

102
batchx_test.go Normal file
View File

@@ -0,0 +1,102 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
//go:build all || integration
// +build all integration
package gocqlx_test
import (
"testing"
"github.com/gocql/gocql"
"github.com/google/go-cmp/cmp"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/gocqlx/v2/gocqlxtest"
"github.com/scylladb/gocqlx/v2/qb"
)
func TestBatch(t *testing.T) {
t.Parallel()
cluster := gocqlxtest.CreateCluster()
if err := gocqlxtest.CreateKeyspace(cluster, "batch_test"); err != nil {
t.Fatal("create keyspace:", err)
}
session, err := gocqlx.WrapSession(cluster.CreateSession())
if err != nil {
t.Fatal("create session:", err)
}
t.Cleanup(func() {
session.Close()
})
basicCreateAndPopulateKeyspace(t, session, "batch_test")
song := Song{
ID: mustParseUUID("60fc234a-8481-4343-93bb-72ecab404863"),
Title: "La Petite Tonkinoise",
Album: "Bye Bye Blackbird",
Artist: "Joséphine Baker",
Tags: []string{"jazz"},
Data: []byte("music"),
}
playlist := PlaylistItem{
ID: mustParseUUID("6a6255d9-680f-4cb5-b9a2-27cf4a810344"),
Title: "La Petite Tonkinoise",
Album: "Bye Bye Blackbird",
Artist: "Joséphine Baker",
SongID: mustParseUUID("60fc234a-8481-4343-93bb-72ecab404863"),
}
insertSong := qb.Insert("batch_test.songs").
Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertPlaylist := qb.Insert("batch_test.playlists").
Columns("id", "title", "album", "artist", "song_id").Query(session)
selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session)
selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session)
t.Run("batch inserts", func(t *testing.T) {
t.Parallel()
type batchQry struct {
qry *gocqlx.Queryx
arg interface{}
}
qrys := []batchQry{
{qry: insertSong, arg: song},
{qry: insertPlaylist, arg: playlist},
}
b := session.NewBatch(gocql.LoggedBatch)
for _, qry := range qrys {
if err := b.BindStruct(qry.qry, qry.arg); err != nil {
t.Fatal("BindStruct failed:", err)
}
}
if err := session.ExecuteBatch(b); err != nil {
t.Fatal("batch execution:", err)
}
// verify song was inserted
var gotSong Song
if err := selectSong.BindStruct(song).Get(&gotSong); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotSong, song); diff != "" {
t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff)
}
// verify playlist item was inserted
var gotPlayList PlaylistItem
if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotPlayList, playlist); diff != "" {
t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff)
}
})
}

View File

@@ -36,7 +36,7 @@ func TestExample(t *testing.T) {
session.ExecStmt(`DROP KEYSPACE examples`) session.ExecStmt(`DROP KEYSPACE examples`)
basicCreateAndPopulateKeyspace(t, session) basicCreateAndPopulateKeyspace(t, session, "examples")
createAndPopulateKeyspaceAllTypes(t, session) createAndPopulateKeyspaceAllTypes(t, session)
basicReadScyllaVersion(t, session) basicReadScyllaVersion(t, session)
@@ -52,56 +52,59 @@ func TestExample(t *testing.T) {
unsetEmptyValues(t, session) unsetEmptyValues(t, session)
} }
// This example shows how to use query builders and table models to build type Song struct {
// queries. It uses "BindStruct" function for parameter binding and "Select"
// function for loading data to a slice.
func basicCreateAndPopulateKeyspace(t *testing.T, session gocqlx.Session) {
err := session.ExecStmt(`CREATE KEYSPACE IF NOT EXISTS examples WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}`)
if err != nil {
t.Fatal("create keyspace:", err)
}
type Song struct {
ID gocql.UUID ID gocql.UUID
Title string Title string
Album string Album string
Artist string Artist string
Tags []string Tags []string
Data []byte Data []byte
} }
type PlaylistItem struct { type PlaylistItem struct {
ID gocql.UUID ID gocql.UUID
Title string Title string
Album string Album string
Artist string Artist string
SongID gocql.UUID SongID gocql.UUID
}
// This example shows how to use query builders and table models to build
// queries. It uses "BindStruct" function for parameter binding and "Select"
// function for loading data to a slice.
func basicCreateAndPopulateKeyspace(t *testing.T, session gocqlx.Session, keyspace string) {
err := session.ExecStmt(fmt.Sprintf(
`CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}`,
keyspace,
))
if err != nil {
t.Fatal("create keyspace:", err)
} }
err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS examples.songs ( err = session.ExecStmt(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.songs (
id uuid PRIMARY KEY, id uuid PRIMARY KEY,
title text, title text,
album text, album text,
artist text, artist text,
tags set<text>, tags set<text>,
data blob)`) data blob)`, keyspace))
if err != nil { if err != nil {
t.Fatal("create table:", err) t.Fatal("create table:", err)
} }
err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS examples.playlists ( err = session.ExecStmt(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.playlists (
id uuid, id uuid,
title text, title text,
album text, album text,
artist text, artist text,
song_id uuid, song_id uuid,
PRIMARY KEY (id, title, album, artist))`) PRIMARY KEY (id, title, album, artist))`, keyspace))
if err != nil { if err != nil {
t.Fatal("create table:", err) t.Fatal("create table:", err)
} }
playlistMetadata := table.Metadata{ playlistMetadata := table.Metadata{
Name: "examples.playlists", Name: fmt.Sprintf("%s.playlists", keyspace),
Columns: []string{"id", "title", "album", "artist", "song_id"}, Columns: []string{"id", "title", "album", "artist", "song_id"},
PartKey: []string{"id"}, PartKey: []string{"id"},
SortKey: []string{"title", "album", "artist", "song_id"}, SortKey: []string{"title", "album", "artist", "song_id"},
@@ -109,7 +112,7 @@ func basicCreateAndPopulateKeyspace(t *testing.T, session gocqlx.Session) {
playlistTable := table.New(playlistMetadata) playlistTable := table.New(playlistMetadata)
// Insert song using query builder. // Insert song using query builder.
insertSong := qb.Insert("examples.songs"). insertSong := qb.Insert(fmt.Sprintf("%s.songs", keyspace)).
Columns("id", "title", "album", "artist", "tags", "data").Query(session) Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertSong.BindStruct(Song{ insertSong.BindStruct(Song{
@@ -275,7 +278,7 @@ func createAndPopulateKeyspaceAllTypes(t *testing.T, session gocqlx.Session) {
insertCheckTypes.BindStruct(CheckTypesStruct{ insertCheckTypes.BindStruct(CheckTypesStruct{
AsciI: "test qscci", AsciI: "test qscci",
BigInt: 9223372036854775806, //MAXINT64 - 1, BigInt: 9223372036854775806, // MAXINT64 - 1,
BloB: []byte("this is blob test"), BloB: []byte("this is blob test"),
BooleaN: false, BooleaN: false,
DatE: date, DatE: date,