From efee4798d6e1672ca10273395c4be22e6e5e87a4 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Wed, 11 Jun 2025 06:16:31 -0400 Subject: [PATCH] Add missing batch API (#325) Some of gocql.Batch API are not present in the gocqlx.Batch. We need to have them implemented and have a test to make sure no API has forgotten. --- batchx.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ batchx_test.go | 21 ++++++++++++++ queryx_test.go | 2 +- 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/batchx.go b/batchx.go index f17539b..f9c8bd8 100644 --- a/batchx.go +++ b/batchx.go @@ -1,6 +1,7 @@ package gocqlx import ( + "context" "fmt" "github.com/gocql/gocql" @@ -70,6 +71,84 @@ func (b *Batch) BindStructMap(qry *Queryx, arg0 interface{}, arg1 map[string]int return nil } +// DefaultTimestamp will enable the with default timestamp flag on the query. +// If enabled, this will replace the server side assigned +// timestamp as default timestamp. Note that a timestamp in the query itself +// will still override this timestamp. This is entirely optional. +// +// Only available on protocol >= 3 +func (b *Batch) DefaultTimestamp(enable bool) *Batch { + b.Batch.DefaultTimestamp(enable) + return b +} + +// Observer enables batch-level observer on this batch. +// The provided observer will be called every time this batched query is executed. +func (b *Batch) Observer(observer gocql.BatchObserver) *Batch { + b.Batch.Observer(observer) + return b +} + +// RetryPolicy sets the retry policy to use when executing the batch operation +func (b *Batch) RetryPolicy(policy gocql.RetryPolicy) *Batch { + b.Batch.RetryPolicy(policy) + return b +} + +// SerialConsistency sets the consistency level for the +// serial phase of conditional updates. That consistency can only be +// either SERIAL or LOCAL_SERIAL and if not present, it defaults to +// SERIAL. This option will be ignored for anything else that a +// conditional update/insert. +// +// Only available for protocol 3 and above +func (b *Batch) SerialConsistency(cons gocql.Consistency) *Batch { + b.Batch.SerialConsistency(cons) + return b +} + +// SpeculativeExecutionPolicy sets the speculative execution policy to use when executing the batch operation +func (b *Batch) SpeculativeExecutionPolicy(policy gocql.SpeculativeExecutionPolicy) *Batch { + b.Batch.SpeculativeExecutionPolicy(policy) + return b +} + +// Trace enables tracing of this batch. Look at the documentation of the +// gocql.Tracer interface to learn more about tracing. +func (b *Batch) Trace(trace gocql.Tracer) *Batch { + b.Batch.Trace(trace) + return b +} + +// WithContext returns a shallow copy of b with its context +// set to ctx. +// +// The provided context controls the entire lifetime of executing a +// query, queries will be canceled and return once the context is +// canceled. +func (b *Batch) WithContext(ctx context.Context) *Batch { + return &Batch{ + Batch: b.Batch.WithContext(ctx), + } +} + +// WithTimestamp will enable the with default timestamp flag on the query +// like DefaultTimestamp does. But also allows to define value for timestamp. +// It works the same way as USING TIMESTAMP in the query itself, but +// should not break prepared query optimization. +// +// Only available on protocol >= 3 +func (b *Batch) WithTimestamp(timestamp int64) *Batch { + b.Batch.WithTimestamp(timestamp) + return b +} + +// Query adds the query to the batch operation +func (b *Batch) Query(stmt string, args ...interface{}) *Batch { + b.Batch.Query(stmt, args...) + return b +} + // ExecuteBatch executes a batch operation and returns nil if successful // otherwise an error describing the failure. func (s *Session) ExecuteBatch(batch *Batch) error { diff --git a/batchx_test.go b/batchx_test.go index f1e3b00..f084fe0 100644 --- a/batchx_test.go +++ b/batchx_test.go @@ -8,6 +8,7 @@ package gocqlx_test import ( + "reflect" "testing" "github.com/gocql/gocql" @@ -189,3 +190,23 @@ func TestBatch(t *testing.T) { } }) } + +func TestBatchAllWrapped(t *testing.T) { + var ( + gocqlType = reflect.TypeOf((*gocql.Batch)(nil)) + gocqlxType = reflect.TypeOf((*gocqlx.Batch)(nil)) + ) + + for i := 0; i < gocqlType.NumMethod(); i++ { + m, ok := gocqlxType.MethodByName(gocqlType.Method(i).Name) + if !ok { + t.Fatalf("Batch missing method %s", gocqlType.Method(i).Name) + } + + for j := 0; j < m.Type.NumOut(); j++ { + if m.Type.Out(j) == gocqlType { + t.Errorf("Batch method %s not wrapped", m.Name) + } + } + } +} diff --git a/queryx_test.go b/queryx_test.go index dfd3a2d..46418ff 100644 --- a/queryx_test.go +++ b/queryx_test.go @@ -212,7 +212,7 @@ func TestQueryxBindMap(t *testing.T) { }) } -func TestQyeryxAllWrapped(t *testing.T) { +func TestQueryxAllWrapped(t *testing.T) { var ( gocqlQueryPtr = reflect.TypeOf((*gocql.Query)(nil)) queryxPtr = reflect.TypeOf((*Queryx)(nil))