Add additional methods to Batch similar to what exists on Queryx

This commit is contained in:
Dmitry Kropachev
2024-06-20 00:38:01 -04:00
committed by Sylwia Szunejko
parent 207ba8723e
commit 2eee0b00f1
2 changed files with 160 additions and 38 deletions

View File

@@ -1,6 +1,8 @@
package gocqlx
import (
"fmt"
"github.com/gocql/gocql"
)
@@ -27,6 +29,38 @@ func (b *Batch) BindStruct(qry *Queryx, arg interface{}) error {
return nil
}
// Bind binds query parameters to values from args.
// If value cannot be found an error is reported.
func (b *Batch) Bind(qry *Queryx, args ...interface{}) error {
if len(qry.Names) != len(args) {
return fmt.Errorf("query requires %d arguments, but %d provided", len(qry.Names), len(args))
}
b.Query(qry.Statement(), args...)
return nil
}
// BindMap binds query named parameters to values from arg using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) BindMap(qry *Queryx, arg map[string]interface{}) error {
args, err := qry.bindMapArgs(arg)
if err != nil {
return err
}
b.Query(qry.Statement(), args...)
return nil
}
// BindStructMap binds query named parameters to values from arg0 and arg1 using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) BindStructMap(qry *Queryx, arg0 interface{}, arg1 map[string]interface{}) error {
args, err := qry.bindStructArgs(arg0, arg1)
if err != nil {
return err
}
b.Query(qry.Statement(), args...)
return nil
}
// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {

View File

@@ -52,52 +52,140 @@ func TestBatch(t *testing.T) {
SongID: mustParseUUID("60fc234a-8481-4343-93bb-72ecab404863"),
}
insertSong := qb.Insert("batch_test.songs").
Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertPlaylist := qb.Insert("batch_test.playlists").
Columns("id", "title", "album", "artist", "song_id").Query(session)
selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session)
selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session)
t.Run("batch inserts", func(t *testing.T) {
t.Parallel()
type batchQry struct {
qry *gocqlx.Queryx
arg interface{}
tcases := []struct {
name string
methodSong func(*gocqlx.Batch, *gocqlx.Queryx, Song) error
methodPlaylist func(*gocqlx.Batch, *gocqlx.Queryx, PlaylistItem) error
}{
{
name: "BindStruct",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.BindStruct(q, song)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.BindStruct(q, playlist)
},
},
{
name: "BindMap",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.BindMap(q, map[string]interface{}{
"id": song.ID,
"title": song.Title,
"album": song.Album,
"artist": song.Artist,
"tags": song.Tags,
"data": song.Data,
})
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.BindMap(q, map[string]interface{}{
"id": playlist.ID,
"title": playlist.Title,
"album": playlist.Album,
"artist": playlist.Artist,
"song_id": playlist.SongID,
})
},
},
{
name: "Bind",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.Bind(q, song.ID, song.Title, song.Album, song.Artist, song.Tags, song.Data)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.Bind(q, playlist.ID, playlist.Title, playlist.Album, playlist.Artist, playlist.SongID)
},
},
{
name: "BindStructMap",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
in := map[string]interface{}{
"title": song.Title,
"album": song.Album,
}
return b.BindStructMap(q, struct {
ID gocql.UUID
Artist string
Tags []string
Data []byte
}{
ID: song.ID,
Artist: song.Artist,
Tags: song.Tags,
Data: song.Data,
}, in)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
in := map[string]interface{}{
"title": playlist.Title,
"album": playlist.Album,
}
return b.BindStructMap(q, struct {
ID gocql.UUID
Artist string
SongID gocql.UUID
}{
ID: playlist.ID,
Artist: playlist.Artist,
SongID: playlist.SongID,
},
in,
)
},
},
}
for _, tcase := range tcases {
t.Run(tcase.name, func(t *testing.T) {
insertSong := qb.Insert("batch_test.songs").
Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertPlaylist := qb.Insert("batch_test.playlists").
Columns("id", "title", "album", "artist", "song_id").Query(session)
selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session)
selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session)
deleteSong := qb.Delete("batch_test.songs").Where(qb.Eq("id")).Query(session)
deletePlaylist := qb.Delete("batch_test.playlists").Where(qb.Eq("id")).Query(session)
qrys := []batchQry{
{qry: insertSong, arg: song},
{qry: insertPlaylist, arg: playlist},
}
b := session.NewBatch(gocql.LoggedBatch)
b := session.NewBatch(gocql.LoggedBatch)
for _, qry := range qrys {
if err := b.BindStruct(qry.qry, qry.arg); err != nil {
t.Fatal("BindStruct failed:", err)
}
}
if err := session.ExecuteBatch(b); err != nil {
t.Fatal("batch execution:", err)
}
if err = tcase.methodSong(b, insertSong, song); err != nil {
t.Fatal("insert song:", err)
}
if err = tcase.methodPlaylist(b, insertPlaylist, playlist); err != nil {
t.Fatal("insert playList:", err)
}
// verify song was inserted
var gotSong Song
if err := selectSong.BindStruct(song).Get(&gotSong); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotSong, song); diff != "" {
t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff)
}
if err := session.ExecuteBatch(b); err != nil {
t.Fatal("batch execution:", err)
}
// verify playlist item was inserted
var gotPlayList PlaylistItem
if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotPlayList, playlist); diff != "" {
t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff)
// verify song was inserted
var gotSong Song
if err := selectSong.BindStruct(song).Get(&gotSong); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotSong, song); diff != "" {
t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff)
}
// verify playlist item was inserted
var gotPlayList PlaylistItem
if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil {
t.Fatal("select playList:", err)
}
if diff := cmp.Diff(gotPlayList, playlist); diff != "" {
t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff)
}
if err = deletePlaylist.BindStruct(playlist).Exec(); err != nil {
t.Error("delete playlist:", err)
}
if err = deleteSong.BindStruct(song).Exec(); err != nil {
t.Error("delete song:", err)
}
})
}
})
}